ab3b02a48e
Populate the repo with the Vault/OpenBao dynamic secrets engine that mints LiteLLM virtual keys scoped by model, spending limit, and lease TTL. - Secrets backend: config, roles, creds paths and a revocable litellm_key type - LiteLLM API client (generate/update/delete/info) with master-key auth - Unit tests (mock LiteLLM) and a docker-compose e2e against both Vault and OpenBao proving the same binary works on each - Makefile, woodpecker CI (build/test/pre-commit), pre-commit config
217 lines
5.1 KiB
Go
217 lines
5.1 KiB
Go
package litellm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
)
|
|
|
|
// getTestBackend returns a configured backend backed by in-memory storage.
|
|
func getTestBackend(t *testing.T) (*litellmBackend, logical.Storage) {
|
|
t.Helper()
|
|
|
|
config := logical.TestBackendConfig()
|
|
config.StorageView = &logical.InmemStorage{}
|
|
config.System = logical.TestSystemView()
|
|
|
|
b, err := Factory(context.Background(), config)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error creating backend: %v", err)
|
|
}
|
|
return b.(*litellmBackend), config.StorageView
|
|
}
|
|
|
|
// mockLiteLLM is an in-memory fake of the LiteLLM key-management API.
|
|
type mockLiteLLM struct {
|
|
server *httptest.Server
|
|
|
|
mu sync.Mutex
|
|
keys map[string]mockKey // key value -> key
|
|
counter int
|
|
masterKey string
|
|
|
|
// generateErr, when set, makes /key/generate return 500.
|
|
generateErr bool
|
|
lastRequest map[string]interface{}
|
|
}
|
|
|
|
// mockKey is the state the fake server keeps for one generated key.
type mockKey struct {
	// Alias is the "key_alias" supplied when the key was generated.
	Alias string
	// Models lists the string entries of the "models" array from the request.
	Models []string
	// MaxBudget is the "max_budget" from the request; nil when none was given.
	MaxBudget *float64
	// Duration is the "duration" string from the request, stored as received.
	Duration string
}
|
|
|
|
func newMockLiteLLM(t *testing.T) *mockLiteLLM {
|
|
t.Helper()
|
|
m := &mockLiteLLM{
|
|
keys: make(map[string]mockKey),
|
|
masterKey: "sk-master-1234",
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/key/generate", m.handleGenerate)
|
|
mux.HandleFunc("/key/delete", m.handleDelete)
|
|
mux.HandleFunc("/key/update", m.handleUpdate)
|
|
mux.HandleFunc("/key/info", m.handleInfo)
|
|
|
|
m.server = httptest.NewServer(m.authMiddleware(mux))
|
|
t.Cleanup(m.server.Close)
|
|
return m
|
|
}
|
|
|
|
func (m *mockLiteLLM) authMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("Authorization") != "Bearer "+m.masterKey {
|
|
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (m *mockLiteLLM) keyCount() int {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return len(m.keys)
|
|
}
|
|
|
|
func (m *mockLiteLLM) handleGenerate(w http.ResponseWriter, r *http.Request) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.generateErr {
|
|
http.Error(w, `{"error":"boom"}`, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
var body map[string]interface{}
|
|
_ = json.NewDecoder(r.Body).Decode(&body)
|
|
m.lastRequest = body
|
|
|
|
m.counter++
|
|
value := "sk-generated-" + strconv.Itoa(m.counter)
|
|
|
|
mk := mockKey{}
|
|
if alias, ok := body["key_alias"].(string); ok {
|
|
mk.Alias = alias
|
|
}
|
|
if models, ok := body["models"].([]interface{}); ok {
|
|
for _, mdl := range models {
|
|
if s, ok := mdl.(string); ok {
|
|
mk.Models = append(mk.Models, s)
|
|
}
|
|
}
|
|
}
|
|
if budget, ok := body["max_budget"].(float64); ok {
|
|
mk.MaxBudget = &budget
|
|
}
|
|
if dur, ok := body["duration"].(string); ok {
|
|
mk.Duration = dur
|
|
}
|
|
m.keys[value] = mk
|
|
|
|
writeJSON(w, map[string]interface{}{
|
|
"key": value,
|
|
"key_name": value,
|
|
"token_id": "tok-" + strconv.Itoa(m.counter),
|
|
"models": mk.Models,
|
|
"max_budget": mk.MaxBudget,
|
|
})
|
|
}
|
|
|
|
func (m *mockLiteLLM) handleDelete(w http.ResponseWriter, r *http.Request) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var body struct {
|
|
Keys []string `json:"keys"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
http.Error(w, `{"error":"bad request"}`, http.StatusBadRequest)
|
|
return
|
|
}
|
|
deleted := 0
|
|
for _, k := range body.Keys {
|
|
if _, ok := m.keys[k]; ok {
|
|
delete(m.keys, k)
|
|
deleted++
|
|
}
|
|
}
|
|
writeJSON(w, map[string]interface{}{"deleted_keys": deleted})
|
|
}
|
|
|
|
func (m *mockLiteLLM) handleUpdate(w http.ResponseWriter, r *http.Request) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var body map[string]interface{}
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
http.Error(w, `{"error":"bad request"}`, http.StatusBadRequest)
|
|
return
|
|
}
|
|
key, _ := body["key"].(string)
|
|
mk, ok := m.keys[key]
|
|
if !ok {
|
|
http.Error(w, `{"error":"not found"}`, http.StatusNotFound)
|
|
return
|
|
}
|
|
if dur, ok := body["duration"].(string); ok {
|
|
mk.Duration = dur
|
|
}
|
|
if budget, ok := body["max_budget"].(float64); ok {
|
|
mk.MaxBudget = &budget
|
|
}
|
|
m.keys[key] = mk
|
|
writeJSON(w, map[string]interface{}{"key": key})
|
|
}
|
|
|
|
func (m *mockLiteLLM) handleInfo(w http.ResponseWriter, r *http.Request) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
key := r.URL.Query().Get("key")
|
|
mk, ok := m.keys[key]
|
|
if !ok {
|
|
http.Error(w, `{"error":"not found"}`, http.StatusNotFound)
|
|
return
|
|
}
|
|
writeJSON(w, map[string]interface{}{
|
|
"key": key,
|
|
"info": map[string]interface{}{
|
|
"models": mk.Models,
|
|
"max_budget": mk.MaxBudget,
|
|
"key_name": mk.Alias,
|
|
},
|
|
})
|
|
}
|
|
|
|
// writeJSON sends v as a 200 response with Content-Type application/json.
// The body is produced by a json.Encoder writing straight to w.
func writeJSON(w http.ResponseWriter, v interface{}) {
	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	// An encoding error is deliberately dropped: the status line has already
	// been written, so there is nothing left to report it with.
	encoder := json.NewEncoder(w)
	_ = encoder.Encode(v)
}
|
|
|
|
// writeTestConfig stores a config pointing at the given base URL.
|
|
func writeTestConfig(t *testing.T, b *litellmBackend, s logical.Storage, baseURL, masterKey string) {
|
|
t.Helper()
|
|
resp, err := b.HandleRequest(context.Background(), &logical.Request{
|
|
Operation: logical.CreateOperation,
|
|
Path: "config",
|
|
Storage: s,
|
|
Data: map[string]interface{}{
|
|
"base_url": baseURL,
|
|
"master_key": masterKey,
|
|
},
|
|
})
|
|
if err != nil || (resp != nil && resp.IsError()) {
|
|
t.Fatalf("failed to write config: err=%v resp=%v", err, resp)
|
|
}
|
|
}
|