Skip to content

Commit

Permalink
Implemented getManagedIdentitySource
Browse files Browse the repository at this point in the history
  • Loading branch information
Robbie-Microsoft committed May 10, 2024
1 parent 8ba150d commit aa22e65
Show file tree
Hide file tree
Showing 12 changed files with 205 additions and 27 deletions.
9 changes: 8 additions & 1 deletion lib/msal-node/src/client/ManagedIdentityApplication.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ import { ClientCredentialClient } from "./ClientCredentialClient";
import { ManagedIdentityClient } from "./ManagedIdentityClient";
import { ManagedIdentityRequestParams } from "../request/ManagedIdentityRequestParams";
import { NodeStorage } from "../cache/NodeStorage";
import { DEFAULT_AUTHORITY_FOR_MANAGED_IDENTITY } from "../utils/Constants";
import {
AzureIdentitySdkManagedIdentitySourceNames,
DEFAULT_AUTHORITY_FOR_MANAGED_IDENTITY,
} from "../utils/Constants";

/**
* Class to initialize a managed identity and identify the service
Expand Down Expand Up @@ -183,4 +186,8 @@ export class ManagedIdentityApplication {
);
}
}

public getManagedIdentitySource(): AzureIdentitySdkManagedIdentitySourceNames {
return this.managedIdentityClient.getManagedIdentitySource();
}
}
35 changes: 35 additions & 0 deletions lib/msal-node/src/client/ManagedIdentityClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import { ManagedIdentityRequest } from "../request/ManagedIdentityRequest";
import { ManagedIdentityId } from "../config/ManagedIdentityId";
import { NodeStorage } from "../cache/NodeStorage";
import { BaseManagedIdentitySource } from "./ManagedIdentitySources/BaseManagedIdentitySource";
import { AzureIdentitySdkManagedIdentitySourceNames } from "../utils/Constants";

/*
* Class to initialize a managed identity and identify the service.
Expand Down Expand Up @@ -73,6 +74,40 @@ export class ManagedIdentityClient {
);
}

private allEnvironmentVariablesAreDefined(
environmentVariables: Array<string | undefined>
): boolean {
return Object.values(environmentVariables).every(
(environmentVariable) => {
return environmentVariable !== undefined;
}
);
}

/**
* Determine the Managed Identity Source based on available environment variables. This API is consumed by Azure Identity SDK.
* @returns AzureIdentitySdkManagedIdentitySourceNames - Azure Identity SDK defined identifiers for the Managed Identity Sources
*/
public getManagedIdentitySource(): AzureIdentitySdkManagedIdentitySourceNames {
return this.allEnvironmentVariablesAreDefined(
ServiceFabric.getEnvironmentVariables()
)
? AzureIdentitySdkManagedIdentitySourceNames.SERVICE_FABRIC
: this.allEnvironmentVariablesAreDefined(
AppService.getEnvironmentVariables()
)
? AzureIdentitySdkManagedIdentitySourceNames.APP_SERVICE
: this.allEnvironmentVariablesAreDefined(
CloudShell.getEnvironmentVariables()
)
? AzureIdentitySdkManagedIdentitySourceNames.CLOUD_SHELL
: this.allEnvironmentVariablesAreDefined(
AzureArc.getEnvironmentVariables()
)
? AzureIdentitySdkManagedIdentitySourceNames.AZURE_ARC
: AzureIdentitySdkManagedIdentitySourceNames.IMDS;
}

/**
* Tries to create a managed identity source for all sources
* @returns the managed identity Source
Expand Down
19 changes: 13 additions & 6 deletions lib/msal-node/src/client/ManagedIdentitySources/AppService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ export class AppService extends BaseManagedIdentitySource {
this.identityHeader = identityHeader;
}

public static tryCreate(
logger: Logger,
nodeStorage: NodeStorage,
networkClient: INetworkModule,
cryptoProvider: CryptoProvider
): AppService | null {
public static getEnvironmentVariables(): Array<string | undefined> {
const identityEndpoint: string | undefined =
process.env[
ManagedIdentityEnvironmentVariableNames.IDENTITY_ENDPOINT
Expand All @@ -58,6 +53,18 @@ export class AppService extends BaseManagedIdentitySource {
ManagedIdentityEnvironmentVariableNames.IDENTITY_HEADER
];

return [identityEndpoint, identityHeader];
}

public static tryCreate(
logger: Logger,
nodeStorage: NodeStorage,
networkClient: INetworkModule,
cryptoProvider: CryptoProvider
): AppService | null {
const [identityEndpoint, identityHeader] =
AppService.getEnvironmentVariables();

// if either of the identity endpoint or identity header variables are undefined, this MSI provider is unavailable.
if (!identityEndpoint || !identityHeader) {
logger.info(
Expand Down
19 changes: 13 additions & 6 deletions lib/msal-node/src/client/ManagedIdentitySources/AzureArc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,26 @@ export class AzureArc extends BaseManagedIdentitySource {
this.identityEndpoint = identityEndpoint;
}

public static getEnvironmentVariables(): Array<string | undefined> {
const identityEndpoint: string | undefined =
process.env[
ManagedIdentityEnvironmentVariableNames.IDENTITY_ENDPOINT
];
const imdsEndpoint: string | undefined =
process.env[ManagedIdentityEnvironmentVariableNames.IMDS_ENDPOINT];

return [identityEndpoint, imdsEndpoint];
}

public static tryCreate(
logger: Logger,
nodeStorage: NodeStorage,
networkClient: INetworkModule,
cryptoProvider: CryptoProvider,
managedIdentityId: ManagedIdentityId
): AzureArc | null {
const identityEndpoint: string | undefined =
process.env[
ManagedIdentityEnvironmentVariableNames.IDENTITY_ENDPOINT
];
const imdsEndpoint: string | undefined =
process.env[ManagedIdentityEnvironmentVariableNames.IMDS_ENDPOINT];
const [identityEndpoint, imdsEndpoint] =
AzureArc.getEnvironmentVariables();

// if either of the identity or imds endpoints are undefined, this MSI provider is unavailable.
if (!identityEndpoint || !imdsEndpoint) {
Expand Down
10 changes: 8 additions & 2 deletions lib/msal-node/src/client/ManagedIdentitySources/CloudShell.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,21 @@ export class CloudShell extends BaseManagedIdentitySource {
this.msiEndpoint = msiEndpoint;
}

public static getEnvironmentVariables(): Array<string | undefined> {
const msiEndpoint: string | undefined =
process.env[ManagedIdentityEnvironmentVariableNames.MSI_ENDPOINT];

return [msiEndpoint];
}

public static tryCreate(
logger: Logger,
nodeStorage: NodeStorage,
networkClient: INetworkModule,
cryptoProvider: CryptoProvider,
managedIdentityId: ManagedIdentityId
): CloudShell | null {
const msiEndpoint: string | undefined =
process.env[ManagedIdentityEnvironmentVariableNames.MSI_ENDPOINT];
const [msiEndpoint] = CloudShell.getEnvironmentVariables();

// if the msi endpoint environment variable is undefined, this MSI provider is unavailable.
if (!msiEndpoint) {
Expand Down
21 changes: 14 additions & 7 deletions lib/msal-node/src/client/ManagedIdentitySources/ServiceFabric.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,7 @@ export class ServiceFabric extends BaseManagedIdentitySource {
this.identityHeader = identityHeader;
}

public static tryCreate(
logger: Logger,
nodeStorage: NodeStorage,
networkClient: INetworkModule,
cryptoProvider: CryptoProvider,
managedIdentityId: ManagedIdentityId
): ServiceFabric | null {
public static getEnvironmentVariables(): Array<string | undefined> {
const identityEndpoint: string | undefined =
process.env[
ManagedIdentityEnvironmentVariableNames.IDENTITY_ENDPOINT
Expand All @@ -64,6 +58,19 @@ export class ServiceFabric extends BaseManagedIdentitySource {
.IDENTITY_SERVER_THUMBPRINT
];

return [identityEndpoint, identityHeader, identityServerThumbprint];
}

public static tryCreate(
logger: Logger,
nodeStorage: NodeStorage,
networkClient: INetworkModule,
cryptoProvider: CryptoProvider,
managedIdentityId: ManagedIdentityId
): ServiceFabric | null {
const [identityEndpoint, identityHeader, identityServerThumbprint] =
ServiceFabric.getEnvironmentVariables();

/*
* if either of the identity endpoint, identity header, or identity server thumbprint
* environment variables are undefined, this MSI provider is unavailable.
Expand Down
13 changes: 13 additions & 0 deletions lib/msal-node/src/utils/Constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ export const ManagedIdentitySourceNames = {
export type ManagedIdentitySourceNames =
(typeof ManagedIdentitySourceNames)[keyof typeof ManagedIdentitySourceNames];

/**
* Azure Identity SDK defined identifiers for the Managed Identity Sources
*/
export const AzureIdentitySdkManagedIdentitySourceNames = {
APP_SERVICE: "APP_SERVICE",
AZURE_ARC: "ARC",
CLOUD_SHELL: "CLOUD_SHELL",
IMDS: "DEFAULT_TO_VM",
SERVICE_FABRIC: "SERVICE_FABRIC",
} as const;
export type AzureIdentitySdkManagedIdentitySourceNames =
(typeof AzureIdentitySdkManagedIdentitySourceNames)[keyof typeof AzureIdentitySdkManagedIdentitySourceNames];

/**
* Managed Identity Ids
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import {
} from "../../test_kit/ManagedIdentityTestUtils";
import { AuthenticationResult } from "@azure/msal-common";
import { ManagedIdentityClient } from "../../../src/client/ManagedIdentityClient";
import { ManagedIdentityEnvironmentVariableNames } from "../../../src/utils/Constants";
import {
AzureIdentitySdkManagedIdentitySourceNames,
ManagedIdentityEnvironmentVariableNames,
} from "../../../src/utils/Constants";

describe("Acquires a token successfully via an App Service Managed Identity", () => {
beforeAll(() => {
Expand Down Expand Up @@ -48,6 +51,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", ()

const managedIdentityApplication: ManagedIdentityApplication =
new ManagedIdentityApplication(userAssignedClientIdConfig);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.APP_SERVICE
);

const networkManagedIdentityResult: AuthenticationResult =
await managedIdentityApplication.acquireToken(
Expand All @@ -65,6 +71,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", ()
managedIdentityApplication = new ManagedIdentityApplication(
systemAssignedConfig
);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.APP_SERVICE
);
});

test("acquires a token", async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ import {
} from "../../../src/error/ManagedIdentityError";
import { ARC_API_VERSION } from "../../../src/client/ManagedIdentitySources/AzureArc";
import * as fs from "fs";
import { ManagedIdentityEnvironmentVariableNames } from "../../../src/utils/Constants";
import {
AzureIdentitySdkManagedIdentitySourceNames,
ManagedIdentityEnvironmentVariableNames,
} from "../../../src/utils/Constants";

jest.mock("fs");

Expand Down Expand Up @@ -64,6 +67,9 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =
managedIdentityApplication = new ManagedIdentityApplication(
systemAssignedConfig
);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.AZURE_ARC
);
});

test("acquires a token", async () => {
Expand Down Expand Up @@ -116,6 +122,9 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =
// managedIdentityIdParams will be omitted for system assigned
},
});
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.AZURE_ARC
);

const networkErrorClient: ManagedIdentityNetworkErrorClient =
new ManagedIdentityNetworkErrorClient();
Expand Down Expand Up @@ -174,6 +183,9 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =

const managedIdentityApplication: ManagedIdentityApplication =
new ManagedIdentityApplication(userAssignedClientIdConfig);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.AZURE_ARC
);

await expect(
managedIdentityApplication.acquireToken(
Expand All @@ -199,6 +211,9 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =
// managedIdentityIdParams will be omitted for system assigned
},
});
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.AZURE_ARC
);

await expect(
managedIdentityApplication.acquireToken(
Expand All @@ -223,6 +238,9 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =
// managedIdentityIdParams will be omitted for system assigned
},
});
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.AZURE_ARC
);

await expect(
managedIdentityApplication.acquireToken(
Expand All @@ -247,6 +265,9 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =
// managedIdentityIdParams will be omitted for system assigned
},
});
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.AZURE_ARC
);

jest.spyOn(fs, "readFileSync").mockImplementationOnce(() => {
throw new Error();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ import {
} from "../../test_kit/ManagedIdentityTestUtils";
import { AuthenticationResult } from "@azure/msal-common";
import { ManagedIdentityClient } from "../../../src/client/ManagedIdentityClient";
import { ManagedIdentityEnvironmentVariableNames } from "../../../src/utils/Constants";
import {
AzureIdentitySdkManagedIdentitySourceNames,
ManagedIdentityEnvironmentVariableNames,
} from "../../../src/utils/Constants";
import {
ManagedIdentityErrorCodes,
createManagedIdentityError,
Expand Down Expand Up @@ -47,6 +50,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", ()
managedIdentityApplication = new ManagedIdentityApplication(
systemAssignedConfig
);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.CLOUD_SHELL
);
});

test("acquires a token", async () => {
Expand Down Expand Up @@ -93,6 +99,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", ()

const managedIdentityApplication: ManagedIdentityApplication =
new ManagedIdentityApplication(userAssignedClientIdConfig);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
AzureIdentitySdkManagedIdentitySourceNames.CLOUD_SHELL
);

await expect(
managedIdentityApplication.acquireToken(
Expand Down
Loading

0 comments on commit aa22e65

Please sign in to comment.