Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save christophersjchow/ee8a3ecc25d5e8b1f5ad7a09b836affe to your computer and use it in GitHub Desktop.

Select an option

Save christophersjchow/ee8a3ecc25d5e8b1f5ad7a09b836affe to your computer and use it in GitHub Desktop.
import {
createAwsGoogleWebIdentityProvider,
} from "../src/awsGoogleWebIdentityProvider";
describe("createAwsGoogleWebIdentityProvider", () => {
  // Options shared by every test; each test supplies its own clock and fakes.
  const baseOptions = {
    roleArn: "arn:aws:iam::123456789012:role/test-role",
    audience: "aws-ses-access",
    region: "ap-southeast-2",
  };
  const startTime = new Date("2026-03-10T00:00:00.000Z").getTime();
  const minutes = (count: number): number => count * 60 * 1000;

  /** Builds a successful STS response whose credential fields all end in `suffix`. */
  const stsResponse = (suffix: string, expiration: Date) => ({
    Credentials: {
      AccessKeyId: `AKIA${suffix}`,
      SecretAccessKey: `SECRET${suffix}`,
      SessionToken: `SESSION${suffix}`,
      Expiration: expiration,
    },
  });

  it("returns cached credentials when not near expiry", async () => {
    const currentTime = startTime;
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token-1");
    const stsSend = jest
      .fn()
      .mockResolvedValue(stsResponse("123", new Date(currentTime + minutes(30))));
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    const first = await provider();
    const second = await provider();

    expect(first).toEqual(second);
    expect(getGoogleIdToken).toHaveBeenCalledTimes(1);
    expect(stsSend).toHaveBeenCalledTimes(1);
  });

  it("refreshes credentials when within refresh skew", async () => {
    let currentTime = startTime;
    const getGoogleIdToken = jest
      .fn()
      .mockResolvedValueOnce("google-token-1")
      .mockResolvedValueOnce("google-token-2");
    const stsSend = jest
      .fn()
      .mockResolvedValueOnce(stsResponse("1", new Date(currentTime + minutes(10))))
      .mockResolvedValueOnce(stsResponse("2", new Date(currentTime + minutes(60))));
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      refreshSkewMs: minutes(5),
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    const first = await provider();
    expect(first.accessKeyId).toBe("AKIA1");

    // 6 minutes later the first credentials are inside the 5-minute skew window.
    currentTime += minutes(6);
    const second = await provider();

    expect(second.accessKeyId).toBe("AKIA2");
    expect(getGoogleIdToken).toHaveBeenCalledTimes(2);
    expect(stsSend).toHaveBeenCalledTimes(2);
  });

  it("deduplicates concurrent refreshes", async () => {
    const currentTime = startTime;
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest.fn().mockImplementation(async () => {
      // Delay so every caller arrives while the first refresh is still pending.
      await new Promise((resolve) => setTimeout(resolve, 20));
      return stsResponse("-CONCURRENT", new Date(currentTime + minutes(60)));
    });
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    const [a, b, c] = await Promise.all([provider(), provider(), provider()]);

    expect(a.accessKeyId).toBe("AKIA-CONCURRENT");
    expect(b.accessKeyId).toBe("AKIA-CONCURRENT");
    expect(c.accessKeyId).toBe("AKIA-CONCURRENT");
    expect(getGoogleIdToken).toHaveBeenCalledTimes(1);
    expect(stsSend).toHaveBeenCalledTimes(1);
  });

  it("supports forceRefresh", async () => {
    const currentTime = startTime;
    const getGoogleIdToken = jest
      .fn()
      .mockResolvedValueOnce("google-token-1")
      .mockResolvedValueOnce("google-token-2");
    const stsSend = jest
      .fn()
      .mockResolvedValueOnce(stsResponse("1", new Date(currentTime + minutes(60))))
      .mockResolvedValueOnce(stsResponse("2", new Date(currentTime + minutes(60))));
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    const first = await provider();
    expect(first.accessKeyId).toBe("AKIA1");

    await provider.forceRefresh();
    const second = await provider();

    expect(second.accessKeyId).toBe("AKIA2");
    expect(getGoogleIdToken).toHaveBeenCalledTimes(2);
    expect(stsSend).toHaveBeenCalledTimes(2);
  });

  it("sends the expected AssumeRoleWithWebIdentity request", async () => {
    const currentTime = startTime;
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest
      .fn()
      .mockResolvedValue(stsResponse("123", new Date(currentTime + minutes(60))));
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await provider();

    expect(getGoogleIdToken).toHaveBeenCalledWith("aws-ses-access");
    expect(stsSend.mock.calls[0][0].input).toEqual({
      RoleArn: baseOptions.roleArn,
      RoleSessionName: `gcp-web-identity-${currentTime}`,
      WebIdentityToken: "google-token",
      DurationSeconds: 3600,
    });
  });

  it("retries after a failed refresh instead of caching the failure", async () => {
    const currentTime = startTime;
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest
      .fn()
      .mockRejectedValueOnce(new Error("transient sts failure"))
      .mockResolvedValueOnce(stsResponse("1", new Date(currentTime + minutes(60))));
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await expect(provider()).rejects.toThrow("transient sts failure");
    const recovered = await provider();

    expect(recovered.accessKeyId).toBe("AKIA1");
    expect(stsSend).toHaveBeenCalledTimes(2);
  });

  it("throws when STS returns incomplete credentials", async () => {
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest.fn().mockResolvedValue({
      Credentials: {
        AccessKeyId: undefined,
        SecretAccessKey: undefined,
      },
    });
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await expect(provider()).rejects.toThrow(
      "AWS STS did not return usable credentials"
    );
  });

  it("propagates Google token acquisition errors", async () => {
    const getGoogleIdToken = jest
      .fn()
      .mockRejectedValue(new Error("failed to get Google token"));
    const stsSend = jest.fn();
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await expect(provider()).rejects.toThrow("failed to get Google token");
    expect(stsSend).not.toHaveBeenCalled();
  });

  it("propagates STS errors", async () => {
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest.fn().mockRejectedValue(new Error("sts failure"));
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await expect(provider()).rejects.toThrow("sts failure");
  });
});
import { GoogleAuth } from "google-auth-library";
import {
AssumeRoleWithWebIdentityCommand,
STSClient,
} from "@aws-sdk/client-sts";
import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types";
/**
 * Configuration for `createAwsGoogleWebIdentityProvider`.
 *
 * Only `roleArn`, `audience` and `region` are required; every optional field
 * has a default applied by the factory.
 */
export interface AwsGoogleWebIdentityProviderOptions {
/** ARN of the IAM role assumed via STS AssumeRoleWithWebIdentity. */
roleArn: string;
/** Audience requested for the Google ID token; passed to `getGoogleIdToken`. */
audience: string;
/** Region for the default STS client; not used when `stsClient` is supplied. */
region: string;
/** Prefix of the STS RoleSessionName; the factory appends `-<now()>`. Defaults to "gcp-web-identity". */
roleSessionNamePrefix?: string;
/** Session duration sent to STS as `DurationSeconds`. Defaults to 3600. */
durationSeconds?: number;
/** Cached credentials are refreshed once they are within this many milliseconds of expiry. Defaults to 5 minutes. */
refreshSkewMs?: number;
/** Clock in epoch milliseconds; used for expiry checks and the session name. Defaults to `Date.now`. */
now?: () => number;
/** Source of the Google ID token for an audience. Defaults to `getGoogleIdTokenWithLibrary`. */
getGoogleIdToken?: (audience: string) => Promise<string>;
/**
 * Client used to send the AssumeRoleWithWebIdentityCommand; only `send` is
 * called, and only `Credentials` is read from its result. Defaults to a new
 * `STSClient` configured for `region`.
 */
stsClient?: {
send: (
command: AssumeRoleWithWebIdentityCommand
) => Promise<{
Credentials?: {
AccessKeyId?: string;
SecretAccessKey?: string;
SessionToken?: string;
Expiration?: Date;
};
}>;
};
}
/**
 * An AWS SDK credential provider (callable with no arguments) that also
 * exposes a way to discard its cached credentials on demand.
 */
export interface RefreshableAwsCredentialProvider
extends Provider<AwsCredentialIdentity> {
/**
 * Drops any cached credentials and resolves once a refresh has completed;
 * rejects if that refresh fails.
 */
forceRefresh: () => Promise<void>;
}
/** Default margin before expiry at which cached credentials stop being reused (5 minutes, in ms). */
const DEFAULT_REFRESH_SKEW_MS = 300 * 1000;
/** Default `DurationSeconds` requested from STS (one hour). */
const DEFAULT_DURATION_SECONDS = 60 * 60;
/** Default prefix for the STS RoleSessionName. */
const DEFAULT_ROLE_SESSION_PREFIX = "gcp-web-identity";
/** Default clock for the provider: the current wall-clock time in epoch milliseconds. */
function defaultNow(): number {
  return new Date().getTime();
}
/**
 * Fetches a Google-signed ID token for `audience` using google-auth-library's
 * default credential discovery.
 *
 * The token is read from the `Authorization` header produced by the library's
 * ID-token client. Both header shapes the library has used are accepted: a
 * plain record (older releases) and a Fetch API `Headers` instance (newer
 * releases — NOTE(review): v10+ per the library's migration notes; confirm
 * against the version pinned in package.json).
 *
 * @param audience - Audience requested for the ID token.
 * @returns The raw token, without the `Bearer ` prefix.
 * @throws Error when no non-empty bearer token can be obtained.
 */
export async function getGoogleIdTokenWithLibrary(
  audience: string
): Promise<string> {
  const auth = new GoogleAuth();
  const client = await auth.getIdTokenClient(audience);
  const headers: unknown = await client.getRequestHeaders();
  const token = extractBearerToken(readAuthorizationHeader(headers));
  if (token === undefined) {
    throw new Error("Failed to obtain Google ID token from google-auth-library");
  }
  return token;
}

/** Reads the Authorization header from a Fetch `Headers`-like object or a plain header record. */
function readAuthorizationHeader(headers: unknown): string | undefined {
  if (typeof headers !== "object" || headers === null) {
    return undefined;
  }
  // Fetch `Headers` exposes a case-insensitive `get`; property access on it returns undefined.
  const getHeader = (headers as { get?: unknown }).get;
  if (typeof getHeader === "function") {
    const value: unknown = getHeader.call(headers, "authorization");
    return typeof value === "string" ? value : undefined;
  }
  const record = headers as Record<string, unknown>;
  const value = record.Authorization ?? record.authorization;
  return typeof value === "string" ? value : undefined;
}

/** Strips the `Bearer ` prefix, returning undefined when it is absent or the token is empty. */
function extractBearerToken(authHeader: string | undefined): string | undefined {
  const prefix = "Bearer ";
  if (authHeader === undefined || !authHeader.startsWith(prefix)) {
    return undefined;
  }
  const token = authHeader.slice(prefix.length).trim();
  return token.length > 0 ? token : undefined;
}
/**
 * Constructs the STS client used when the caller does not inject one.
 *
 * @param region - AWS region the client is configured for.
 * @returns A new `STSClient` configured only with `region`.
 */
function buildStsClient(region: string): STSClient {
  const clientConfig = { region };
  return new STSClient(clientConfig);
}
/**
 * Converts the `Credentials` of an STS AssumeRoleWithWebIdentity response into
 * the SDK's `AwsCredentialIdentity` shape.
 *
 * AssumeRoleWithWebIdentity only issues temporary credentials, and AWS rejects
 * temporary keys that are not accompanied by their session token. All three
 * secrets are therefore required, so unusable credentials are never cached.
 *
 * @param credentials - The (possibly partial) `Credentials` object from STS.
 * @returns The credential identity; `expiration` is copied through as given.
 * @throws Error when the access key id, secret access key or session token is missing or empty.
 */
function toAwsCredentialIdentity(credentials: {
  AccessKeyId?: string;
  SecretAccessKey?: string;
  SessionToken?: string;
  Expiration?: Date;
}): AwsCredentialIdentity {
  if (
    !credentials.AccessKeyId ||
    !credentials.SecretAccessKey ||
    !credentials.SessionToken
  ) {
    throw new Error("AWS STS did not return usable credentials");
  }
  return {
    accessKeyId: credentials.AccessKeyId,
    secretAccessKey: credentials.SecretAccessKey,
    sessionToken: credentials.SessionToken,
    expiration: credentials.Expiration,
  };
}
/**
 * Decides whether cached credentials may still be handed out.
 *
 * @param creds - Cached credentials, or `undefined` when nothing is cached.
 * @param refreshSkewMs - Safety margin subtracted from the expiration time.
 * @param now - Current time in epoch milliseconds.
 * @returns `false` when nothing is cached; `true` for credentials without an
 *   expiration; otherwise whether `now` is still before `expiration - refreshSkewMs`.
 */
function isCredentialUsable(
  creds: AwsCredentialIdentity | undefined,
  refreshSkewMs: number,
  now: number
): boolean {
  if (!creds) return false;
  const { expiration } = creds;
  // Credentials without an expiration are always treated as usable.
  return expiration ? now < expiration.getTime() - refreshSkewMs : true;
}
/**
 * Creates an AWS credential provider that exchanges a Google ID token for
 * temporary AWS credentials via STS AssumeRoleWithWebIdentity.
 *
 * Credentials are cached in memory and reused until they come within
 * `refreshSkewMs` of their expiration. Callers arriving while a refresh is
 * pending share that single refresh. A failed refresh is not cached: the error
 * propagates to every waiting caller and the next call starts a new refresh.
 *
 * @param options - Role, audience and region plus optional overrides (clock,
 *   token source, STS client) used for every refresh.
 * @returns A provider function with a `forceRefresh` method that discards the
 *   cache and resolves once a refresh has completed.
 */
export function createAwsGoogleWebIdentityProvider(
  options: AwsGoogleWebIdentityProviderOptions
): RefreshableAwsCredentialProvider {
  // STS rejects RoleSessionName values longer than 64 characters.
  const maxRoleSessionNameLength = 64;

  const {
    roleArn,
    audience,
    region,
    roleSessionNamePrefix = DEFAULT_ROLE_SESSION_PREFIX,
    durationSeconds = DEFAULT_DURATION_SECONDS,
    refreshSkewMs = DEFAULT_REFRESH_SKEW_MS,
    now = defaultNow,
    getGoogleIdToken = getGoogleIdTokenWithLibrary,
    stsClient = buildStsClient(region),
  } = options;

  let cachedCredentials: AwsCredentialIdentity | undefined;
  let inflightRefresh: Promise<AwsCredentialIdentity> | undefined;

  /** Builds `<prefix>-<now()>`, trimming the prefix so the name fits the STS length limit. */
  const buildRoleSessionName = (): string => {
    const suffix = `-${now()}`;
    const prefixBudget = Math.max(0, maxRoleSessionNameLength - suffix.length);
    return `${roleSessionNamePrefix.slice(0, prefixBudget)}${suffix}`;
  };

  /** Performs one Google-token -> STS exchange and caches the result on success. */
  const refresh = async (): Promise<AwsCredentialIdentity> => {
    const webIdentityToken = await getGoogleIdToken(audience);
    const response = await stsClient.send(
      new AssumeRoleWithWebIdentityCommand({
        RoleArn: roleArn,
        RoleSessionName: buildRoleSessionName(),
        WebIdentityToken: webIdentityToken,
        DurationSeconds: durationSeconds,
      })
    );
    const credentials = toAwsCredentialIdentity(response.Credentials ?? {});
    cachedCredentials = credentials;
    return credentials;
  };

  /** Returns the pending refresh if there is one, otherwise starts a new one. */
  const startOrJoinRefresh = (): Promise<AwsCredentialIdentity> => {
    if (!inflightRefresh) {
      // Clear the slot once settled (success or failure) so later calls can refresh again.
      inflightRefresh = refresh().finally(() => {
        inflightRefresh = undefined;
      });
    }
    return inflightRefresh;
  };

  const getCredentials = async (): Promise<AwsCredentialIdentity> => {
    const cached = cachedCredentials;
    if (cached && isCredentialUsable(cached, refreshSkewMs, now())) {
      return cached;
    }
    return startOrJoinRefresh();
  };

  const forceRefresh = async (): Promise<void> => {
    cachedCredentials = undefined;
    // A refresh that is already pending will replace the cache, so join it
    // instead of issuing a second STS call.
    await startOrJoinRefresh();
  };

  return Object.assign(getCredentials, { forceRefresh });
}
/**
 * Options for `createSesCredentialProvider`; identical to
 * `AwsGoogleWebIdentityProviderOptions`. Declared as an alias rather than an
 * empty interface, which typescript-eslint's recommended rules reject.
 */
export type SesClientFactoryOptions = AwsGoogleWebIdentityProviderOptions;
/**
 * Creates the credential provider used for SES access.
 *
 * This delegates to `createAwsGoogleWebIdentityProvider`. The return type is
 * the refreshable subtype, so callers can still pass it wherever a plain
 * `Provider<AwsCredentialIdentity>` is expected. They can also call
 * `forceRefresh`, for example after AWS rejects the current credentials.
 *
 * @param options - Same options as the underlying web-identity provider.
 * @returns A caching, refreshable AWS credential provider.
 */
export function createSesCredentialProvider(
  options: SesClientFactoryOptions
): RefreshableAwsCredentialProvider {
  return createAwsGoogleWebIdentityProvider(options);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment