"""get_and_validate - obtain an Azure access token and validate it with python-jose.

Client side: acquire a token with the MSAL client-credentials flow.
Server side: download the published public keys, pick the key whose ``kid``
matches the token header, and decode/verify the token with python-jose.
"""
import argparse
import json
import sys

import msal
import requests
from jose import jws
from jose import jwt
from jose.backends import RSAKey
from loguru import logger

parser = argparse.ArgumentParser(
    description="get_and_validate - Obtain a token from the Azure Identity Provider and validate it using python-jose"
)
parser.add_argument(
    '--authority', '-a',
    help="URL to Identity Provider, includes tenant id",
    required=True
)
parser.add_argument(
    '--client', '-c',
    help="Client or Application ID",
    required=True
)
parser.add_argument(
    '--scope', '-s',
    help="Scope",
    required=True
)
parser.add_argument(
    '--secret', '-p',
    help="Client Secret aka password",
    required=True
)
parser.add_argument(
    '--keyurl', '-k',
    help="URL where Microsoft publishes public keys",
    required=False,
    default="https://login.microsoftonline.com/common/discovery/keys"
)
parser.add_argument(
    '--verbose', '-v',
    help="Verbose output",
    action='store_true',
    default=False
)
args = parser.parse_args()

authority = args.authority
clientId = args.client
scope = args.scope
secret = args.secret
keyUrl = args.keyurl
verbose = args.verbose

# Drop loguru's existing handler(s) and log to stdout at the requested level.
logLevel = "DEBUG" if verbose else "INFO"
logger.remove()
logger.add(sys.stdout, level=logLevel)

# -------------------------------------------------------------------------------
# --- client side - obtain a token ----------------------------------------------
# -------------------------------------------------------------------------------

# instance for connection to identity provider (confidential client = app + secret)
app = msal.ConfidentialClientApplication(
    clientId,
    authority=authority,
    client_credential=secret
)
logger.debug("Identity Provider Instance created")

# Best practice: look in the token cache first and only contact the identity
# provider on a miss.
# NOTE(review): no persistent token_cache is passed above, so in a fresh process
# this lookup presumably always misses — confirm against the MSAL version in use.
result = app.acquire_token_silent([scope], account=None)
logger.debug(f"Token from cache: {result}")

if not result:
    logger.debug("No token in cache found")
    result = app.acquire_token_for_client(scopes=[scope])
    logger.debug(f"New token object: {result}")
# MSAL reports acquisition failures as a dict with 'error' / 'error_description'
# instead of raising, so check for the token explicitly rather than dying with
# an uninformative KeyError.
if 'access_token' not in result:
    msg = f"Unable to obtain token, error: {result.get('error')}, description: {result.get('error_description')}"
    logger.error(msg)
    raise Exception(msg)
token = result['access_token']
logger.info(f"Actual token: {token}")

# -------------------------------------------------------------------------------
# --- server side - validate a token --------------------------------------------
# -------------------------------------------------------------------------------

# Obtain the public keys. The timeout (seconds) keeps the script from hanging
# forever on an unresponsive key endpoint; requests has no default timeout.
keysResponse = requests.get(keyUrl, timeout=30)
if keysResponse.status_code != 200:
    msg = f"Unable to get public keys, code: {keysResponse.status_code}, reason: {keysResponse.reason}"
    logger.error(msg)
    raise Exception(msg)
keys = keysResponse.json()['keys']
logger.debug(f"{keys=}")

# get unverified header from token to find key-id and algorithm
unverifiedTokenHeader = jws.get_unverified_header(token)
logger.debug(f"{unverifiedTokenHeader=}")
signingKeyId = unverifiedTokenHeader['kid']
logger.debug(f"{signingKeyId=}")
signingAlgorithm = unverifiedTokenHeader['alg']
logger.debug(f"{signingAlgorithm=}")

# get unverified claims from token to find the audience
unverifiedClaimsStr = jws.get_unverified_claims(token)
unverifiedClaims = json.loads(unverifiedClaimsStr)
logger.debug(f"{unverifiedClaims=}")
# NOTE(review): the audience is read from the token's own (unverified) 'aud'
# claim, so the audience check in jwt.decode below can never fail for it.
# A real resource server should compare against the audience it expects
# (its own App ID URI / client id), and should also validate the issuer.
audience = unverifiedClaims['aud']
logger.debug(f"{audience=}")

# Pick the published key whose key-id matches the token header. Use .get so a
# key entry without 'kid' is skipped rather than raising KeyError.
signingKey = next((k for k in keys if k.get('kid') == signingKeyId), None)
if signingKey is None:
    msg = f"No signing key with id {signingKeyId} found"
    logger.error(msg)
    raise Exception(msg)
logger.debug(f"{signingKey=}")

# check if it is an RSA key
if signingKey['kty'] != 'RSA':
    msg = f"Signing key is not an RSA key but {signingKey['kty']}"
    logger.error(msg)
    raise Exception(msg)

# put key into jwk object
key = RSAKey(signingKey, signingAlgorithm)
logger.debug(f"Public key is {key.to_pem()}")

# decode token including validation of signature and audience
tokenInfo = jwt.decode(token, key, audience=audience)
logger.info(f"{tokenInfo=}")
logger.info("Token successfully verified")