Created
March 10, 2026 07:20
-
-
Save christophersjchow/ee8a3ecc25d5e8b1f5ad7a09b836affe to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import { | |
| createAwsGoogleWebIdentityProvider, | |
| } from "../src/awsGoogleWebIdentityProvider"; | |
describe("createAwsGoogleWebIdentityProvider", () => {
  // Fixed wall-clock start so every expiration below is deterministic.
  const START_TIME = new Date("2026-03-10T00:00:00.000Z").getTime();
  const MINUTE_MS = 60 * 1000;

  // Options every test shares; each test layers its own clock and fakes on top.
  const baseOptions = {
    roleArn: "arn:aws:iam::123456789012:role/test-role",
    audience: "aws-ses-access",
    region: "ap-southeast-2",
  };

  // Fake STS AssumeRoleWithWebIdentity response whose key, secret and
  // session token are all tagged with `id`.
  const stsResponse = (id: string, expiration: Date) => ({
    Credentials: {
      AccessKeyId: `AKIA${id}`,
      SecretAccessKey: `SECRET${id}`,
      SessionToken: `SESSION${id}`,
      Expiration: expiration,
    },
  });

  it("returns cached credentials when not near expiry", async () => {
    const currentTime = START_TIME;
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token-1");
    const stsSend = jest
      .fn()
      .mockResolvedValue(
        stsResponse("123", new Date(currentTime + 30 * MINUTE_MS))
      );
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    const first = await provider();
    const second = await provider();

    expect(first.accessKeyId).toBe("AKIA123");
    expect(second).toEqual(first);
    expect(getGoogleIdToken).toHaveBeenCalledTimes(1);
    expect(stsSend).toHaveBeenCalledTimes(1);
  });

  it("exchanges the Google ID token for the configured role", async () => {
    const currentTime = START_TIME;
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token-1");
    const stsSend = jest
      .fn()
      .mockResolvedValue(
        stsResponse("1", new Date(currentTime + 60 * MINUTE_MS))
      );
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await provider();

    expect(getGoogleIdToken).toHaveBeenCalledWith("aws-ses-access");
    expect(stsSend.mock.calls[0][0].input).toMatchObject({
      RoleArn: baseOptions.roleArn,
      RoleSessionName: `gcp-web-identity-${currentTime}`,
      WebIdentityToken: "google-token-1",
      DurationSeconds: 3600,
    });
  });

  it("refreshes credentials when within refresh skew", async () => {
    let currentTime = START_TIME;
    const getGoogleIdToken = jest
      .fn()
      .mockResolvedValueOnce("google-token-1")
      .mockResolvedValueOnce("google-token-2");
    const stsSend = jest
      .fn()
      .mockResolvedValueOnce(
        stsResponse("1", new Date(currentTime + 10 * MINUTE_MS))
      )
      .mockResolvedValueOnce(
        stsResponse("2", new Date(currentTime + 60 * MINUTE_MS))
      );
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      refreshSkewMs: 5 * MINUTE_MS,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    const first = await provider();
    expect(first.accessKeyId).toBe("AKIA1");

    // 4 minutes of validity left, which is inside the 5 minute skew.
    currentTime += 6 * MINUTE_MS;
    const second = await provider();

    expect(second.accessKeyId).toBe("AKIA2");
    expect(getGoogleIdToken).toHaveBeenCalledTimes(2);
    expect(stsSend).toHaveBeenCalledTimes(2);
  });

  it("deduplicates concurrent refreshes", async () => {
    const currentTime = START_TIME;
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest
      .fn()
      .mockResolvedValue(
        stsResponse("-CONCURRENT", new Date(currentTime + 60 * MINUTE_MS))
      );
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    // All three calls are issued in the same tick, before any can complete.
    const [a, b, c] = await Promise.all([provider(), provider(), provider()]);

    expect(a.accessKeyId).toBe("AKIA-CONCURRENT");
    expect(b.accessKeyId).toBe("AKIA-CONCURRENT");
    expect(c.accessKeyId).toBe("AKIA-CONCURRENT");
    expect(getGoogleIdToken).toHaveBeenCalledTimes(1);
    expect(stsSend).toHaveBeenCalledTimes(1);
  });

  it("supports forceRefresh", async () => {
    const currentTime = START_TIME;
    const getGoogleIdToken = jest
      .fn()
      .mockResolvedValueOnce("google-token-1")
      .mockResolvedValueOnce("google-token-2");
    const stsSend = jest
      .fn()
      .mockResolvedValueOnce(
        stsResponse("1", new Date(currentTime + 60 * MINUTE_MS))
      )
      .mockResolvedValueOnce(
        stsResponse("2", new Date(currentTime + 60 * MINUTE_MS))
      );
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    const first = await provider();
    expect(first.accessKeyId).toBe("AKIA1");

    await provider.forceRefresh();
    const second = await provider();

    expect(second.accessKeyId).toBe("AKIA2");
    expect(getGoogleIdToken).toHaveBeenCalledTimes(2);
    expect(stsSend).toHaveBeenCalledTimes(2);
  });

  it("forceRefresh joins a refresh that is already in flight", async () => {
    const currentTime = START_TIME;
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest
      .fn()
      .mockResolvedValue(
        stsResponse("1", new Date(currentTime + 60 * MINUTE_MS))
      );
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    const pending = provider();
    await provider.forceRefresh();

    await expect(pending).resolves.toMatchObject({ accessKeyId: "AKIA1" });
    expect(stsSend).toHaveBeenCalledTimes(1);
  });

  it("retries after a failed refresh instead of caching the failure", async () => {
    const currentTime = START_TIME;
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest
      .fn()
      .mockRejectedValueOnce(new Error("sts failure"))
      .mockResolvedValueOnce(
        stsResponse("1", new Date(currentTime + 60 * MINUTE_MS))
      );
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      now: () => currentTime,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await expect(provider()).rejects.toThrow("sts failure");
    const recovered = await provider();

    expect(recovered.accessKeyId).toBe("AKIA1");
    expect(stsSend).toHaveBeenCalledTimes(2);
  });

  it("forceRefresh propagates refresh failures", async () => {
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest.fn().mockRejectedValue(new Error("sts failure"));
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await expect(provider.forceRefresh()).rejects.toThrow("sts failure");
  });

  it("throws when STS returns incomplete credentials", async () => {
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest.fn().mockResolvedValue({
      Credentials: {
        AccessKeyId: undefined,
        SecretAccessKey: undefined,
      },
    });
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await expect(provider()).rejects.toThrow(
      "AWS STS did not return usable credentials"
    );
  });

  it("propagates Google token acquisition errors", async () => {
    const getGoogleIdToken = jest
      .fn()
      .mockRejectedValue(new Error("failed to get Google token"));
    const stsSend = jest.fn();
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await expect(provider()).rejects.toThrow("failed to get Google token");
    expect(stsSend).not.toHaveBeenCalled();
  });

  it("propagates STS errors", async () => {
    const getGoogleIdToken = jest.fn().mockResolvedValue("google-token");
    const stsSend = jest.fn().mockRejectedValue(new Error("sts failure"));
    const provider = createAwsGoogleWebIdentityProvider({
      ...baseOptions,
      getGoogleIdToken,
      stsClient: { send: stsSend },
    });

    await expect(provider()).rejects.toThrow("sts failure");
  });
});
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import { GoogleAuth } from "google-auth-library"; | |
| import { | |
| AssumeRoleWithWebIdentityCommand, | |
| STSClient, | |
| } from "@aws-sdk/client-sts"; | |
| import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types"; | |
/**
 * Configuration accepted by `createAwsGoogleWebIdentityProvider`.
 *
 * Only `roleArn`, `audience` and `region` are required; every other field
 * overrides a default chosen by the factory.
 */
export interface AwsGoogleWebIdentityProviderOptions {
  /** ARN of the IAM role assumed with the Google ID token. */
  roleArn: string;

  /** Audience requested when fetching the Google ID token. */
  audience: string;

  /** Region for the STS client that is built when `stsClient` is omitted. */
  region: string;

  /**
   * Prefix of the STS role session name; `-` plus the value of `now()` is
   * appended. Defaults to `"gcp-web-identity"`.
   */
  roleSessionNamePrefix?: string;

  /** Requested session duration in seconds. Defaults to 3600. */
  durationSeconds?: number;

  /**
   * How many milliseconds before expiration cached credentials are treated
   * as stale and refreshed. Defaults to five minutes.
   */
  refreshSkewMs?: number;

  /** Clock returning epoch milliseconds. Defaults to `Date.now()`. */
  now?: () => number;

  /** Source of Google ID tokens. Defaults to `getGoogleIdTokenWithLibrary`. */
  getGoogleIdToken?: (audience: string) => Promise<string>;

  /** STS client override (useful for tests); only `send` is called. */
  stsClient?: {
    send: (command: AssumeRoleWithWebIdentityCommand) => Promise<{
      Credentials?: {
        AccessKeyId?: string;
        SecretAccessKey?: string;
        SessionToken?: string;
        Expiration?: Date;
      };
    }>;
  };
}
/**
 * An AWS SDK credential provider that can also be told to discard its cache.
 *
 * Calling the provider itself yields credentials; `forceRefresh` is the
 * extra hook for callers that know the current credentials are bad.
 */
export interface RefreshableAwsCredentialProvider
  extends Provider<AwsCredentialIdentity> {
  /** Drops any cached credentials and resolves once fresh ones were fetched. */
  forceRefresh: () => Promise<void>;
}
// Refresh cached credentials once fewer than five minutes of validity remain.
const DEFAULT_REFRESH_SKEW_MS = 1000 * 60 * 5;
// Request one-hour STS sessions unless told otherwise.
const DEFAULT_DURATION_SECONDS = 60 * 60;
// Prefix of the STS role session name when the caller supplies none.
const DEFAULT_ROLE_SESSION_PREFIX = "gcp-web-identity";
/** Wall clock used when no `now` override is supplied: epoch milliseconds. */
function defaultNow(): number {
  return new Date().getTime();
}
// Matches `Bearer <token>`. The auth scheme is case-insensitive (RFC 7235),
// and the token must be non-empty.
const BEARER_TOKEN_PATTERN = /^Bearer\s+(\S+)\s*$/i;

/**
 * Reads the Authorization header from the value returned by
 * `IdTokenClient.getRequestHeaders()`.
 *
 * google-auth-library v10+ returns a WHATWG `Headers` instance, which must
 * be read through its case-insensitive `get`. Earlier versions return a
 * plain header record.
 *
 * @param headers - Whatever `getRequestHeaders()` resolved to.
 * @returns The header value, or `undefined` when absent or not a string.
 */
function readAuthorizationHeader(headers: unknown): string | undefined {
  if (typeof headers !== "object" || headers === null) {
    return undefined;
  }
  if (typeof (headers as { get?: unknown }).get === "function") {
    const value = (headers as { get(name: string): unknown }).get(
      "authorization"
    );
    return typeof value === "string" ? value : undefined;
  }
  const record = headers as Record<string, unknown>;
  const value = record.Authorization ?? record.authorization;
  return typeof value === "string" ? value : undefined;
}

/**
 * Fetches a Google-signed ID token for `audience` using a default-configured
 * `GoogleAuth` instance (Application Default Credentials).
 *
 * @param audience - Target audience requested for the ID token.
 * @returns The raw ID token, without the `Bearer ` prefix.
 * @throws Error if the library does not yield a non-empty bearer
 *   Authorization header.
 */
export async function getGoogleIdTokenWithLibrary(
  audience: string
): Promise<string> {
  const auth = new GoogleAuth();
  const client = await auth.getIdTokenClient(audience);
  // Typed as unknown because the header container differs across library versions.
  const headers: unknown = await client.getRequestHeaders();
  const authHeader = readAuthorizationHeader(headers);
  const token =
    authHeader === undefined
      ? undefined
      : BEARER_TOKEN_PATTERN.exec(authHeader)?.[1];
  if (!token) {
    throw new Error("Failed to obtain Google ID token from google-auth-library");
  }
  return token;
}
/** Default STS client factory: a client configured only with `region`. */
function buildStsClient(region: string): STSClient {
  const clientConfig = { region };
  return new STSClient(clientConfig);
}
/**
 * Converts the `Credentials` member of an STS AssumeRoleWithWebIdentity
 * response into the AWS SDK credential identity shape.
 *
 * Web-identity credentials are temporary, so two extra fields are required:
 * - The session token: the keys are unusable without it.
 * - The expiration: the caching layer treats credentials without an
 *   expiration as never expiring, which would pin them forever.
 *
 * A payload missing any required field is rejected here, rather than
 * failing later in a less obvious place.
 *
 * @param credentials - The (possibly partial) STS credential payload.
 * @returns The credentials as an `AwsCredentialIdentity`.
 * @throws Error if the access key, secret, session token or expiration is missing.
 */
function toAwsCredentialIdentity(credentials: {
  AccessKeyId?: string;
  SecretAccessKey?: string;
  SessionToken?: string;
  Expiration?: Date;
}): AwsCredentialIdentity {
  const { AccessKeyId, SecretAccessKey, SessionToken, Expiration } = credentials;
  if (!AccessKeyId || !SecretAccessKey || !SessionToken || !Expiration) {
    throw new Error("AWS STS did not return usable credentials");
  }
  return {
    accessKeyId: AccessKeyId,
    secretAccessKey: SecretAccessKey,
    sessionToken: SessionToken,
    expiration: Expiration,
  };
}
/**
 * Decides whether cached credentials can be handed out without refreshing.
 *
 * @param creds - Cached credentials, if any.
 * @param refreshSkewMs - Safety margin: credentials count as stale this many
 *   milliseconds before they actually expire.
 * @param now - Current time in epoch milliseconds.
 * @returns `false` when nothing is cached. `true` when the credentials carry
 *   no expiration. Otherwise whether `now` is strictly before the expiration
 *   minus the skew.
 */
function isCredentialUsable(
  creds: AwsCredentialIdentity | undefined,
  refreshSkewMs: number,
  now: number
): boolean {
  if (!creds) {
    return false;
  }
  const { expiration } = creds;
  return expiration ? now < expiration.getTime() - refreshSkewMs : true;
}
/**
 * Creates an AWS credential provider that trades a Google ID token for
 * temporary AWS credentials via STS `AssumeRoleWithWebIdentity`.
 *
 * Caching and refresh behaviour:
 * - Credentials are cached and reused until they are within `refreshSkewMs`
 *   of their expiration.
 * - Callers that find the cache stale share one in-flight refresh, so only
 *   one Google token fetch and one STS call run at a time.
 * - A failed refresh is not cached; the next call starts a fresh attempt.
 *
 * @param options - Role, audience and region, plus optional overrides
 *   (clock, token source, STS client, session naming, timing).
 * @returns A provider function usable as AWS SDK `credentials`. It carries
 *   an extra `forceRefresh()` that discards the cache and waits for new
 *   credentials, joining a refresh that is already in flight instead of
 *   starting a second one.
 */
export function createAwsGoogleWebIdentityProvider(
  options: AwsGoogleWebIdentityProviderOptions
): RefreshableAwsCredentialProvider {
  // STS rejects RoleSessionName values longer than 64 characters.
  const maxRoleSessionNameLength = 64;

  const {
    roleArn,
    audience,
    region,
    roleSessionNamePrefix = DEFAULT_ROLE_SESSION_PREFIX,
    durationSeconds = DEFAULT_DURATION_SECONDS,
    refreshSkewMs = DEFAULT_REFRESH_SKEW_MS,
    now = defaultNow,
    getGoogleIdToken = getGoogleIdTokenWithLibrary,
    stsClient = buildStsClient(region),
  } = options;

  let cachedCredentials: AwsCredentialIdentity | undefined;
  let inflightRefresh: Promise<AwsCredentialIdentity> | undefined;

  // `<prefix>-<now()>`, with the prefix trimmed so the timestamp suffix
  // always fits within the STS length limit.
  const buildRoleSessionName = (): string => {
    const suffix = `-${now()}`;
    const prefixBudget = Math.max(0, maxRoleSessionNameLength - suffix.length);
    return roleSessionNamePrefix.slice(0, prefixBudget) + suffix;
  };

  // Fetches a Google ID token, exchanges it with STS and caches the result.
  const refresh = async (): Promise<AwsCredentialIdentity> => {
    const webIdentityToken = await getGoogleIdToken(audience);
    const response = await stsClient.send(
      new AssumeRoleWithWebIdentityCommand({
        RoleArn: roleArn,
        RoleSessionName: buildRoleSessionName(),
        WebIdentityToken: webIdentityToken,
        DurationSeconds: durationSeconds,
      })
    );
    const credentials = toAwsCredentialIdentity(response.Credentials ?? {});
    cachedCredentials = credentials;
    return credentials;
  };

  // Returns the refresh already in flight, or starts a new one.
  const startOrJoinRefresh = (): Promise<AwsCredentialIdentity> => {
    if (inflightRefresh) {
      return inflightRefresh;
    }
    const pending = refresh().finally(() => {
      // This callback runs before `pending` settles. So by the time any
      // caller sees the outcome, the slot is free again and a failure will
      // be retried on the next call.
      inflightRefresh = undefined;
    });
    inflightRefresh = pending;
    return pending;
  };

  const getCredentials = async (): Promise<AwsCredentialIdentity> => {
    const currentNow = now();
    const cached = cachedCredentials;
    if (cached && isCredentialUsable(cached, refreshSkewMs, currentNow)) {
      return cached;
    }
    return startOrJoinRefresh();
  };

  const forceRefresh = async (): Promise<void> => {
    cachedCredentials = undefined;
    await startOrJoinRefresh();
  };

  return Object.assign(getCredentials, { forceRefresh });
}
/**
 * Options for `createSesCredentialProvider`. These are the same options
 * as the generic Google web-identity provider, kept as a separate name for
 * SES-specific call sites.
 */
export interface SesClientFactoryOptions
  extends AwsGoogleWebIdentityProviderOptions {}

/**
 * Builds the credential provider used for SES clients.
 *
 * @param options - Web-identity provider options.
 * @returns A credential provider backed by
 *   `createAwsGoogleWebIdentityProvider`.
 */
export function createSesCredentialProvider(
  options: SesClientFactoryOptions
): Provider<AwsCredentialIdentity> {
  const sesProvider: Provider<AwsCredentialIdentity> =
    createAwsGoogleWebIdentityProvider(options);
  return sesProvider;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment