# get_and_validate - Obtain a token from the Azure Identity Provider and validate it using python-jose
from loguru import logger
|
|
import sys
|
|
import argparse
|
|
|
|
import msal
|
|
|
|
import requests
|
|
import json
|
|
from jose import jwt
|
|
from jose import jws
|
|
from jose.backends import RSAKey
|
|
|
|
|
|
# -------------------------------------------------------------------------------
# --- command-line interface ----------------------------------------------------
# -------------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="get_and_validate - Obtain a token from the Azure Identity Provider and validate it using python-jose"
)

# One entry per option: (flags, help text, remaining add_argument keywords).
# Options are registered in this order, which is also the --help order.
_OPTION_SPECS = (
    (('--authority', '-a'), "URL to Identity Provider, includes tenant id", {'required': True}),
    (('--client', '-c'), "Client or Application ID", {'required': True}),
    (('--scope', '-s'), "Scope", {'required': True}),
    # NOTE(review): a secret passed on the command line can leak through shell
    # history or process listings; consider reading it from the environment.
    (('--secret', '-p'), "Client Secret aka password", {'required': True}),
    (
        ('--keyurl', '-k'),
        "URL where Microsoft publishes public keys",
        {
            'required': False,
            'default': "https://login.microsoftonline.com/common/discovery/keys",
        },
    ),
    (('--verbose', '-v'), "Verbose output", {'action': 'store_true', 'default': False}),
)
for _flags, _help_text, _extra in _OPTION_SPECS:
    parser.add_argument(*_flags, help=_help_text, **_extra)

args = parser.parse_args()

# Copy the parsed values into the module-level names used by the sections below.
authority, clientId, scope, secret, keyUrl, verbose = (
    args.authority,
    args.client,
    args.scope,
    args.secret,
    args.keyurl,
    args.verbose,
)

# Log to stdout; --verbose lowers the threshold from INFO to DEBUG.
logLevel = "DEBUG" if verbose else "INFO"

# logger.remove() with no handler id drops every sink registered so far, so the
# stdout sink added next is the only one that receives messages.
logger.remove()
logger.add(sys.stdout, level=logLevel)

# -------------------------------------------------------------------------------
# --- client side - obtain a token ----------------------------------------------
# -------------------------------------------------------------------------------
# instance for connection to identity provider (confidential client: id + secret)
app = msal.ConfidentialClientApplication(
    clientId,
    authority=authority,
    client_credential=secret,
)
logger.debug("Identity Provider Instance created")

# best practice to acquire token: try the token cache first ...
result = app.acquire_token_silent([scope], account=None)
logger.debug(f"Token from cache: {result}")

# ... and only ask the identity provider when nothing was cached
if not result:
    logger.debug("No token in cache found")
    result = app.acquire_token_for_client(scopes=[scope])
    logger.debug(f"New token object: {result}")

# MSAL reports failures as a dict carrying "error"/"error_description" instead
# of "access_token"; surface that instead of failing with a bare KeyError.
if "access_token" not in result:
    msg = (
        f"Unable to obtain token, error: {result.get('error')}, "
        f"description: {result.get('error_description')}"
    )
    logger.error(msg)
    raise Exception(msg)

token = result['access_token']
# NOTE(review): this prints a live bearer credential at INFO level.
logger.info(f"Actual token: {token}")

# -------------------------------------------------------------------------------
# --- server side - validate a token --------------------------------------------
# -------------------------------------------------------------------------------
# obtain the public keys (JSON document with a "keys" list)
# requests waits forever without a timeout; bound it so an unresponsive key
# endpoint makes the script fail instead of hang.
keysResponse = requests.get(keyUrl, timeout=30)
if keysResponse.status_code != 200:
    msg = f"Unable to get public keys, code: {keysResponse.status_code}, reason: {keysResponse.reason}"
    logger.error(msg)
    raise Exception(msg)

keys = keysResponse.json()['keys']
logger.debug(f"{keys=}")

# Peek at the token WITHOUT verifying it, only to learn which published key
# ("kid") and which algorithm ("alg") were used to sign it.
unverifiedTokenHeader = jws.get_unverified_header(token)
logger.debug(f"{unverifiedTokenHeader=}")
signingKeyId = unverifiedTokenHeader['kid']  # matched against the downloaded keys
logger.debug(f"{signingKeyId=}")
signingAlgorithm = unverifiedTokenHeader['alg']  # handed to RSAKey below
logger.debug(f"{signingAlgorithm=}")

# The claims segment is JSON; decode it to read the audience ("aud").
unverifiedClaims = json.loads(jws.get_unverified_claims(token))
logger.debug(f"{unverifiedClaims=}")
# NOTE(review): the expected audience is copied from the token itself, so the
# later jwt.decode(..., audience=audience) check can never reject a token
# issued for a different API. Real validation should compare against a
# configured audience (e.g. a dedicated CLI option).
audience = unverifiedClaims['aud']
logger.debug(f"{audience=}")

# Select the published key whose "kid" matches the token header. Entries without
# a "kid" are skipped rather than aborting the lookup with a KeyError.
signingKey = next(
    (candidate for candidate in keys if candidate.get('kid') == signingKeyId),
    None,
)
if signingKey is None:
    msg = f"No signing key with id {signingKeyId} found"
    # log the message itself (the previous code logged the literal text "f{msg}")
    logger.error(msg)
    raise Exception(msg)

logger.debug(f"{signingKey=}")

# Only RSA keys are accepted, since the key is wrapped in python-jose's RSAKey.
kty = signingKey['kty']
if kty != 'RSA':
    msg = f"Signing key is not an RSA key but {kty}"
    logger.error(msg)
    raise Exception(msg)

# Turn the JWK dict into a python-jose key object bound to the algorithm named
# in the (unverified) token header.
key = RSAKey(signingKey, signingAlgorithm)
publicKeyPem = key.to_pem()
logger.debug(f"Public key is {publicKeyPem}")

# Decode and verify: python-jose checks the signature against `key` and the
# "aud" claim against `audience`; other claims follow its default options.
# NOTE(review): neither `issuer=` nor `algorithms=` is passed, so the issuing
# tenant is not pinned. Consider configuring the expected issuer explicitly.
tokenInfo = jwt.decode(token, key=key, audience=audience)
logger.info(f"{tokenInfo=}")

logger.info("Token successfully verified")