Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 29 additions & 105 deletions containers/api-proxy/oidc-token-provider.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,11 @@
*/

const { mintGitHubOidcToken, httpPost } = require('./github-oidc');
const { logRequest } = require('./logging');

// Refresh at 75% of token lifetime (Azure tokens typically last 3600s)
const REFRESH_FACTOR = 0.75;
// Minimum seconds before expiry to trigger refresh
const MIN_REFRESH_MARGIN_SECS = 300;
// Retry delay after failed refresh (ms)
const REFRESH_RETRY_DELAY_MS = 30_000;
// Maximum retries for initial token acquisition
const MAX_INIT_RETRIES = 3;
const {
BaseOidcTokenProvider,
REFRESH_FACTOR,
MIN_REFRESH_MARGIN_SECS,
} = require('./oidc-token-provider-base');

/**
* @typedef {Object} OidcTokenProviderConfig
Expand All @@ -39,28 +34,22 @@ const MAX_INIT_RETRIES = 3;
* @property {number} [maxInitRetries] - Maximum retries for initial token acquisition (default: 3)
*/

class OidcTokenProvider {
class OidcTokenProvider extends BaseOidcTokenProvider {
/**
* @param {OidcTokenProviderConfig} config
*/
constructor(config) {
super('oidc', config);
this._requestUrl = config.requestUrl;
this._requestToken = config.requestToken;
this._tenantId = config.tenantId;
this._clientId = config.clientId;
this._oidcAudience = config.oidcAudience || 'api://AzureADTokenExchange';
this._azureScope = config.azureScope || 'https://cognitiveservices.azure.com/.default';
this._loginHost = this._resolveLoginHost(config.azureCloud);
this._retryDelayMs = config.retryDelayMs ?? REFRESH_RETRY_DELAY_MS;
this._maxInitRetries = config.maxInitRetries ?? MAX_INIT_RETRIES;

// Token state
this._cachedToken = null;
this._expiresAt = 0; // Unix timestamp (seconds)
this._refreshTimer = null;
this._refreshInFlight = null;
this._initialized = false;
this._initError = null;
}

/**
Expand All @@ -76,44 +65,6 @@ class OidcTokenProvider {
}
}

/**
* Initialize the token provider by acquiring the first token.
* Must be called (and awaited) before getToken() is usable.
* @returns {Promise<void>}
*/
async initialize() {
for (let attempt = 1; attempt <= this._maxInitRetries; attempt++) {
try {
await this._refreshToken();
this._initialized = true;
this._initError = null;
logRequest('info', 'oidc_init_success', {
tenant_id: this._tenantId,
client_id: this._clientId,
scope: this._azureScope,
expires_in_secs: this._expiresAt - Math.floor(Date.now() / 1000),
});
return;
} catch (err) {
this._initError = err;
logRequest('warn', 'oidc_init_retry', {
attempt,
max_retries: this._maxInitRetries,
error: err.message,
});
if (attempt < this._maxInitRetries) {
await this._sleep(this._retryDelayMs * attempt);
}
}
}
// All retries failed — log but don't throw; getToken() will return null
logRequest('error', 'oidc_init_failed', {
error: this._initError?.message,
tenant_id: this._tenantId,
client_id: this._clientId,
});
}

/**
* Get the current cached token synchronously.
* Returns null if no valid token is available.
Expand All @@ -131,25 +82,6 @@ class OidcTokenProvider {
return null;
}

/**
* Whether the provider has a usable token.
* @returns {boolean}
*/
isReady() {
const now = Math.floor(Date.now() / 1000);
return !!(this._cachedToken && this._expiresAt > now);
}

/**
* Stop background refresh timers.
*/
shutdown() {
if (this._refreshTimer) {
clearTimeout(this._refreshTimer);
this._refreshTimer = null;
}
}

/**
* Mint a GitHub OIDC token with the specified audience.
* @returns {Promise<string>} The GitHub-issued JWT
Expand Down Expand Up @@ -223,33 +155,6 @@ class OidcTokenProvider {
this._scheduleRefresh(Math.floor(refreshInSecs * 1000));
}

/**
* Schedule a background token refresh.
* @param {number} delayMs
*/
_scheduleRefresh(delayMs) {
if (this._refreshTimer) clearTimeout(this._refreshTimer);
this._refreshTimer = setTimeout(() => {
this._refreshInFlight = this._refreshToken()
.then(() => {
logRequest('info', 'oidc_refresh_success', {
expires_in_secs: this._expiresAt - Math.floor(Date.now() / 1000),
});
})
.catch((err) => {
logRequest('error', 'oidc_refresh_failed', { error: err.message });
// Retry after delay if token is still valid
const now = Math.floor(Date.now() / 1000);
if (this._expiresAt > now) {
this._scheduleRefresh(this._retryDelayMs);
}
})
.finally(() => { this._refreshInFlight = null; });
}, delayMs);
// Don't let refresh timer keep the process alive
if (this._refreshTimer.unref) this._refreshTimer.unref();
}

/**
* HTTP POST helper.
* @param {string} url
Expand All @@ -261,9 +166,28 @@ class OidcTokenProvider {
return httpPost(url, body, headers);
}

/** @param {number} ms */
_sleep(ms) {
return new Promise(resolve => setTimeout(resolve, ms));
async _doRefresh() {
await this._refreshToken();
}

_getCachedValue() {
return this._cachedToken;
}

_getInitSuccessLogContext() {
return {
tenant_id: this._tenantId,
client_id: this._clientId,
scope: this._azureScope,
expires_in_secs: this._expiresAt - Math.floor(Date.now() / 1000),
};
}

_getInitFailureLogContext() {
return {
tenant_id: this._tenantId,
client_id: this._clientId,
};
}
}

Expand Down
18 changes: 18 additions & 0 deletions containers/api-proxy/oidc-token-provider.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,24 @@ describe('OidcTokenProvider', () => {
expect(provider._scheduleRefresh).toHaveBeenCalledWith(0);
provider.shutdown();
});

it('should not trigger refresh after shutdown', async () => {
const provider = new OidcTokenProvider({
requestUrl: 'http://localhost/token',
requestToken: 'test',
tenantId: 'test',
clientId: 'test',
});

provider._refreshToken = jest.fn().mockResolvedValue();
provider.shutdown();

expect(provider.getToken()).toBeNull();
await new Promise(resolve => setTimeout(resolve, 20));

expect(provider._refreshToken).not.toHaveBeenCalled();
expect(provider._refreshTimer).toBeNull();
});
});

describe('OpenAI adapter with OIDC', () => {
Expand Down
Loading