Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions containers/api-proxy/management.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const metrics = require('./metrics');
* @property {() => import('./rate-limiter').RateLimiter} getLimiter
* @property {string|undefined} httpsProxy - Value of HTTPS_PROXY env var at startup
* @property {() => object|null} getModelAliases - Returns parsed MODEL_ALIASES (or null)
* @property {() => { enabled: boolean, strategy: string }} getModelFallback - Returns fallback config
* @property {() => object} getEffectiveTokenUsage - Returns effective token usage summary
* @property {() => object} getMaxRunsUsage - Returns max-runs usage summary
*/
Expand All @@ -46,6 +47,7 @@ function createManagementHandlers(deps) {
getLimiter,
httpsProxy,
getModelAliases,
getModelFallback,
getEffectiveTokenUsage,
getMaxRunsUsage,
} = deps;
Expand Down Expand Up @@ -95,6 +97,7 @@ function createManagementHandlers(deps) {
}),
models_fetch_complete: isModelFetchComplete(),
model_aliases: modelAliases ? modelAliases.models : null,
model_fallback: getModelFallback(),
effective_tokens: getEffectiveTokenUsage(),
runs: getMaxRunsUsage(),
};
Expand Down
42 changes: 42 additions & 0 deletions containers/api-proxy/model-discovery.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,46 @@ const MODELS_LOG_DIR = process.env.AWF_API_PROXY_LOG_DIR || '/var/log/api-proxy'

const GEMINI_MODEL_NAME_PREFIX = 'models/';

function getModelCapabilityTier(provider, modelId) {
const providerKey = String(provider || '').toLowerCase();
const model = String(modelId || '').toLowerCase();

if (providerKey === 'anthropic') {
if (model.includes('opus')) return 5;
if (model.includes('sonnet')) return 4;
if (model.includes('haiku')) return 3;
return 1;
}

if (providerKey === 'openai' || providerKey === 'copilot') {
if (/gpt-5(?:[.\-]|$)/i.test(model)) return 5;
if (/gpt-4(?:[.\-]|$)/i.test(model) || model.includes('gpt-4o')) return 4;
if (model.includes('gpt-3.5')) return 3;
return 1;
}

return null;
}

function getTierSortedModels(provider, models) {
if (!Array.isArray(models) || models.length === 0) return [];

const unique = [...new Set(models.filter(m => typeof m === 'string' && m.length > 0))];
if (unique.length === 0) return [];

const ranked = unique.map(model => ({
model,
tier: getModelCapabilityTier(provider, model),
}));

const hasTiering = ranked.some(entry => Number.isFinite(entry.tier));
ranked.sort((a, b) => {
if (!hasTiering) return a.model.localeCompare(b.model);
return (b.tier - a.tier) || a.model.localeCompare(b.model);
});
return ranked;
}

// ── buildRequest ──────────────────────────────────────────────────────────────
/**
* Shared HTTP/HTTPS request setup used by fetchJson and httpProbe.
Expand Down Expand Up @@ -225,6 +265,8 @@ module.exports = {
fetchJson,
httpProbe,
extractModelIds,
getModelCapabilityTier,
getTierSortedModels,
buildModelsJson,
writeModelsJson,
};
160 changes: 144 additions & 16 deletions containers/api-proxy/model-resolver.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
* case-insensitive, and sorted by semver semantics (highest version first).
*/

const { getTierSortedModels } = require('./model-discovery');

const DEFAULT_MODEL_FALLBACK = Object.freeze({
enabled: true,
strategy: 'middle_power',
});

/**
* Parse model aliases configuration from a raw JSON string.
*
Expand All @@ -37,12 +44,23 @@ function parseModelAliases(rawConfig) {
if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) return null;
if (!parsed.models || typeof parsed.models !== 'object' || Array.isArray(parsed.models)) return null;

// Validate structure: each value must be an array of strings
// Validate structure: each value must be either:
// - string[] (legacy alias syntax)
// - { patterns: string[], fallback?: boolean } (extended alias syntax)
for (const [, value] of Object.entries(parsed.models)) {
if (!Array.isArray(value)) return null;
for (const entry of value) {
if (Array.isArray(value)) {
for (const entry of value) {
if (typeof entry !== 'string') return null;
}
continue;
}

if (!value || typeof value !== 'object' || Array.isArray(value)) return null;
if (!Array.isArray(value.patterns)) return null;
for (const entry of value.patterns) {
if (typeof entry !== 'string') return null;
}
if (value.fallback !== undefined && typeof value.fallback !== 'boolean') return null;
}

return { models: parsed.models };
Expand Down Expand Up @@ -103,6 +121,67 @@ function compareByVersion(a, b) {
return a.localeCompare(b); // Lexicographic fallback
}

function normalizeFallbackConfig(modelFallbackConfig) {
const config = modelFallbackConfig && typeof modelFallbackConfig === 'object'
? modelFallbackConfig
: DEFAULT_MODEL_FALLBACK;
return {
enabled: config.enabled !== false,
strategy: config.strategy || 'middle_power',
};
}

function resolveAliasDefinition(rawAlias) {
if (Array.isArray(rawAlias)) {
return { patterns: rawAlias, fallback: true };
}
if (!rawAlias || typeof rawAlias !== 'object' || Array.isArray(rawAlias)) {
return { patterns: [], fallback: true };
}
return {
patterns: Array.isArray(rawAlias.patterns) ? rawAlias.patterns : [],
fallback: rawAlias.fallback !== false,
};
}

function inferModelFamilyPrefix(requestedModel) {
const key = String(requestedModel || '').toLowerCase();
const gptFamily = key.match(/^(gpt-\d+(?:\.\d+)?)/)?.[1];
if (gptFamily) return gptFamily;
if (key.includes('claude')) return 'claude';
if (key.includes('gemini')) return 'gemini';
return null;
}

function selectMiddlePowerFallback(requestedModel, availableModels, currentProvider, reason, modelFallbackConfig) {
const fallbackConfig = normalizeFallbackConfig(modelFallbackConfig);
if (!fallbackConfig.enabled || fallbackConfig.strategy !== 'middle_power') return null;

const providerModels = Array.isArray(availableModels[currentProvider]) ? availableModels[currentProvider] : [];
if (providerModels.length === 0) return null;

const familyPrefix = inferModelFamilyPrefix(requestedModel);
const familyCandidates = familyPrefix
? providerModels.filter(model => model.toLowerCase().startsWith(familyPrefix))
: [];
const selectedPool = familyCandidates.length > 0 ? familyCandidates : providerModels;
const sortedCandidates = getTierSortedModels(currentProvider, selectedPool);
if (sortedCandidates.length === 0) return null;

const medianIndex = Math.floor((sortedCandidates.length - 1) / 2);
return {
resolvedModel: sortedCandidates[medianIndex].model,
fallback: {
activated: true,
reason,
selection_method: 'middle_power_median',
available_models_count: providerModels.length,
used_family_filter: familyCandidates.length > 0,
candidates: sortedCandidates,
},
};
}

/**
* Resolve a model name through the alias chain for a given provider.
*
Expand All @@ -115,15 +194,17 @@ function compareByVersion(a, b) {
* 3. Collect all candidates, sort by version (highest first), return the best match
*
* @param {string} requestedModel - Model name from the request body (or "" for default)
* @param {Record<string, string[]>} aliases - Alias map from parseModelAliases()
* @param {Record<string, string[]|{patterns: string[], fallback?: boolean}>} aliases - Alias map from parseModelAliases()
* @param {Record<string, string[]|null>} availableModels - Cached provider models
* @param {string} currentProvider - Provider handling this request (e.g. "copilot")
* @param {string[]} [chain=[]] - Accumulates visited alias names for loop detection
* @returns {{ resolvedModel: string, log: string[] } | null}
* @param {{ enabled?: boolean, strategy?: string }} [modelFallbackConfig]
* @returns {{ resolvedModel: string, log: string[], fallback?: object } | null}
*/
function resolveModel(requestedModel, aliases, availableModels, currentProvider, chain = []) {
function resolveModel(requestedModel, aliases, availableModels, currentProvider, chain = [], modelFallbackConfig = DEFAULT_MODEL_FALLBACK) {
const log = [];
const key = requestedModel.toLowerCase();
const fallbackConfig = normalizeFallbackConfig(modelFallbackConfig);

// ── Loop detection ────────────────────────────────────────────────────────
if (chain.includes(key)) {
Expand Down Expand Up @@ -155,7 +236,13 @@ function resolveModel(requestedModel, aliases, availableModels, currentProvider,
const direct = providerModels.find(m => m.toLowerCase() === key);
if (direct) {
log.push(`[model-resolver] direct match: "${requestedModel}" → "${direct}"`);
return { resolvedModel: direct, log };
return {
resolvedModel: direct,
log,
fallback: fallbackConfig.enabled
? { activated: false, selection_method: 'middle_power_median', reason: 'direct_match' }
: undefined,
};
}

// If a gpt-5.<minor> model is requested but unavailable, fall back to the
Expand All @@ -168,14 +255,33 @@ function resolveModel(requestedModel, aliases, availableModels, currentProvider,
const sorted = [...new Set(familyCandidates)].sort(compareByVersion);
const fallback = sorted[0];
log.push(`[model-resolver] requested model "${requestedModel}" not available, falling back to "${fallback}"`);
return { resolvedModel: fallback, log };
return {
resolvedModel: fallback,
log,
fallback: fallbackConfig.enabled
? { activated: false, selection_method: 'middle_power_median', reason: 'family_version_fallback' }
: undefined,
};
}
}
const middlePowerFallback = selectMiddlePowerFallback(
requestedModel,
availableModels,
currentProvider,
'no_alias_match_and_not_in_available_models',
fallbackConfig
);
if (middlePowerFallback) {
log.push(`[model-resolver] middle-power fallback: "${requestedModel}" → "${middlePowerFallback.resolvedModel}"`);
return { resolvedModel: middlePowerFallback.resolvedModel, log, fallback: middlePowerFallback.fallback };
}
// No match at all — cannot resolve.
return null;
}

const [aliasKey, patterns] = aliasEntry;
const [aliasKey, aliasRaw] = aliasEntry;
const aliasDefinition = resolveAliasDefinition(aliasRaw);
const patterns = aliasDefinition.patterns;
log.push(`[model-resolver] alias: "${requestedModel}" → [${patterns.join(', ')}]`);

// ── Expand each pattern ───────────────────────────────────────────────────
Expand All @@ -186,7 +292,7 @@ function resolveModel(requestedModel, aliases, availableModels, currentProvider,

if (slashIdx === -1) {
// Recursive alias reference (no provider prefix)
const sub = resolveModel(pattern, aliases, availableModels, currentProvider, newChain);
const sub = resolveModel(pattern, aliases, availableModels, currentProvider, newChain, fallbackConfig);
if (sub) {
log.push(...sub.log);
candidates.push(sub.resolvedModel);
Expand All @@ -209,6 +315,20 @@ function resolveModel(requestedModel, aliases, availableModels, currentProvider,

if (candidates.length === 0) {
log.push(`[model-resolver] no candidates found for "${aliasKey}" on provider "${currentProvider}"`);
const hasProviderPattern = patterns.some((pattern) => pattern.includes('/'));
if (aliasDefinition.fallback && hasProviderPattern) {
const middlePowerFallback = selectMiddlePowerFallback(
requestedModel,
availableModels,
currentProvider,
'no_alias_match_and_not_in_available_models',
fallbackConfig
);
if (middlePowerFallback) {
log.push(`[model-resolver] middle-power fallback: "${requestedModel}" → "${middlePowerFallback.resolvedModel}"`);
return { resolvedModel: middlePowerFallback.resolvedModel, log, fallback: middlePowerFallback.fallback };
}
}
return null;
}

Expand All @@ -225,7 +345,13 @@ function resolveModel(requestedModel, aliases, availableModels, currentProvider,
: '')
);

return { resolvedModel: resolved, log };
return {
resolvedModel: resolved,
log,
fallback: fallbackConfig.enabled
? { activated: false, selection_method: 'middle_power_median', reason: 'normal_resolution_succeeded' }
: undefined,
};
}

/**
Expand All @@ -236,11 +362,12 @@ function resolveModel(requestedModel, aliases, availableModels, currentProvider,
*
* @param {Buffer} body - Raw request body bytes
* @param {string} provider - Current provider (e.g. "copilot")
* @param {Record<string, string[]>} aliases - Parsed alias map
* @param {Record<string, string[]|{patterns: string[], fallback?: boolean}>} aliases - Parsed alias map
* @param {Record<string, string[]|null>} availableModels - Cached models per provider
* @returns {{ body: Buffer, originalModel: string, resolvedModel: string, log: string[] } | null}
* @param {{ enabled?: boolean, strategy?: string }} [modelFallbackConfig]
* @returns {{ body: Buffer, originalModel: string, resolvedModel: string, log: string[], fallback?: object } | null}
*/
function rewriteModelInBody(body, provider, aliases, availableModels) {
function rewriteModelInBody(body, provider, aliases, availableModels, modelFallbackConfig = DEFAULT_MODEL_FALLBACK) {
// Only attempt rewrite for non-empty bodies
if (!body || body.length === 0) return null;

Expand All @@ -256,7 +383,7 @@ function rewriteModelInBody(body, provider, aliases, availableModels) {
// Determine the requested model. If absent, try the default alias ("").
const originalModel = typeof parsed.model === 'string' ? parsed.model : '';

const resolution = resolveModel(originalModel, aliases, availableModels, provider);
const resolution = resolveModel(originalModel, aliases, availableModels, provider, [], modelFallbackConfig);
if (!resolution) return null;

const { resolvedModel, log } = resolution;
Expand All @@ -268,14 +395,15 @@ function rewriteModelInBody(body, provider, aliases, availableModels) {
parsed.model = resolvedModel;
const newBody = Buffer.from(JSON.stringify(parsed), 'utf8');

return { body: newBody, originalModel, resolvedModel, log };
return { body: newBody, originalModel, resolvedModel, log, fallback: resolution.fallback };
}

module.exports = {
parseModelAliases,
globMatch,
extractVersionNumbers,
compareByVersion,
selectMiddlePowerFallback,
resolveModel,
rewriteModelInBody,
};
Loading
Loading