diff --git a/src/commands/login.ts b/src/commands/login.ts index 523eeb6..886203f 100644 --- a/src/commands/login.ts +++ b/src/commands/login.ts @@ -8,7 +8,12 @@ This file is part of @p0security/cli You should have received a copy of the GNU General Public License along with @p0security/cli. If not, see . **/ -import { IDENTITY_FILE_PATH, authenticate } from "../drivers/auth"; +import { + IDENTITY_CACHE_PATH, + IDENTITY_FILE_PATH, + authenticate, + loadCredentials, +} from "../drivers/auth"; import { doc, guard } from "../drivers/firestore"; import { print2 } from "../drivers/stdio"; import { pluginLoginMap } from "../plugins/login"; @@ -40,6 +45,12 @@ export const login = async ( if (!loginFn) throw "Unsupported login for your organization"; const tokenResponse = await loginFn(orgWithSlug); + // if the user changed their org, clear any cached identities this prevents + // commands like `aws assume role` from using the old identities + const currentIdentity = await loadCredentials().catch(() => undefined); + if (currentIdentity?.org.slug !== args.org) { + await clearIdentityCache(); + } await writeIdentity(orgWithSlug, tokenResponse); // validate auth @@ -69,6 +80,16 @@ const writeIdentity = async (org: OrgData, credential: TokenResponse) => { ); }; +const clearIdentityCache = async () => { + try { + // check to see if the directory exists before trying to remove it + await fs.access(IDENTITY_CACHE_PATH); + await fs.rm(IDENTITY_CACHE_PATH, { recursive: true }); + } catch { + return; + } +}; + export const loginCommand = (yargs: yargs.Argv) => yargs.command<{ org: string }>( "login ", diff --git a/src/drivers/auth.ts b/src/drivers/auth.ts index 2d30bab..9173d0c 100644 --- a/src/drivers/auth.ts +++ b/src/drivers/auth.ts @@ -22,6 +22,10 @@ import * as fs from "fs/promises"; import * as path from "path"; export const IDENTITY_FILE_PATH = path.join(P0_PATH, "identity.json"); +export const IDENTITY_CACHE_PATH = path.join( + path.dirname(IDENTITY_FILE_PATH), + "cache" +); export const cached = async ( name: string, @@ -29,11 +33,10 @@ export const cached = async ( options: { duration: number }, hasExpired?: (data: T) => boolean ): Promise => { - const cachePath = path.join(path.dirname(IDENTITY_FILE_PATH), "cache"); // Following lines sanitize input // nosemgrep: javascript.lang.security.audit.path-traversal.path-join-resolve-traversal.path-join-resolve-traversal - const loc = path.resolve(path.join(cachePath, `${name}.json`)); - if (!loc.startsWith(cachePath)) { + const loc = path.resolve(path.join(IDENTITY_CACHE_PATH, `${name}.json`)); + if (!loc.startsWith(IDENTITY_CACHE_PATH)) { throw new Error("Illegal path traversal"); }