authservice/auth.py
2021-09-06 21:23:54 +02:00

372 lines
13 KiB
Python
Executable File

import configparser
import json
import os
import random
import secrets
import string
import time
from collections import namedtuple

import connexion
import psycopg2
import werkzeug
from jose import JWTError, jwt, jwe
from jose.exceptions import ExpiredSignatureError
from loguru import logger
from pbkdf2 import crypt
# Database credentials and the JWT issuer name.
#
# The environment is consulted first. If ANY of the five variables is
# missing, all five values are taken from ./config/config.ini instead
# (sections [database] and [jwt]).
DB_USER = DB_PASS = DB_HOST = DB_NAME = JWT_ISSUER = ""
try:
    DB_USER, DB_PASS, DB_HOST, DB_NAME, JWT_ISSUER = [
        os.environ[envName]
        for envName in ("DB_USER", "DB_PASS", "DB_HOST", "DB_NAME", "JWT_ISSUER")
    ]
except KeyError:
    config = configparser.ConfigParser()
    config.read('./config/config.ini')
    DB_USER, DB_PASS, DB_HOST, DB_NAME = [
        config["database"][option]
        for option in ("user", "pass", "host", "name")
    ]
    JWT_ISSUER = config["jwt"]["issuer"]
class NoUserException(Exception):
    """Raised when a user lookup returns no row."""


class RefreshTokenExpiredException(Exception):
    """Raised when the ``exp`` claim of a refresh token lies in the past."""


class NoTokenException(Exception):
    """Raised when a token row could not be created or found."""


class NoValidTokenException(Exception):
    """Signals that no valid token was found.

    It is caught by ``refreshTokens``; no code in this module raises it.
    """


class ManyUsersException(Exception):
    """Raised when a user lookup unexpectedly returns more than one row."""


class ManyTokensException(Exception):
    """Raised when a token query unexpectedly returns more than one row."""


class PasswordMismatchException(Exception):
    """Raised when the supplied password does not match the stored hash."""
# Result of a user lookup: database id, login name, stored password hash,
# token lifetime (added to the issue timestamp to form "exp") and the
# collected claims dictionary.
UserEntry = namedtuple('UserEntry', 'id login pwhash expiry claims')
# A freshly created refresh-token record: token id, random salt, login name,
# application name and lifetime.
RefreshTokenEntry = namedtuple('RefreshTokenEntry', 'id salt login app expiry')
# Private key used to sign tokens: taken from the environment if present,
# otherwise read from ./config/authservice.key.
JWT_PRIV_KEY = os.environ.get("JWT_PRIV_KEY")
if JWT_PRIV_KEY is None:
    with open('./config/authservice.key', 'r') as f:
        JWT_PRIV_KEY = f.read()

# Public key used to verify tokens: taken from the environment if present,
# otherwise read from ./config/authservice.pub.
JWT_PUB_KEY = os.environ.get("JWT_PUB_KEY")
if JWT_PUB_KEY is None:
    with open('./config/authservice.pub', 'r') as f:
        JWT_PUB_KEY = f.read()
def getUserEntryFromDB(application: str, login: str):
    """Load a user and the user's claims for one application.

    Exactly one row is expected in ``user_application_v`` for the given
    application/login pair. All rows of ``claims_for_user_v`` for that user
    and application are then folded into a dictionary: a key that occurs
    once maps to its value, a key that occurs repeatedly maps to a list of
    its values.

    Returns:
        A ``UserEntry`` with id, login, password hash, expiry and claims.

    Raises:
        NoUserException: no matching user row exists.
        ManyUsersException: more than one matching user row exists.
        Exception: a ``psycopg2.Error`` occurred (message wraps the cause).
    """
    conn = None
    try:
        conn = psycopg2.connect(user=DB_USER, password=DB_PASS,
                                host=DB_HOST, database=DB_NAME)
        conn.autocommit = False

        with conn.cursor() as cur:
            cur.execute("SELECT id, pwhash, expiry FROM user_application_v"
                        " WHERE application = %s AND login = %s",
                        (application, login))
            userRow = cur.fetchone()
            logger.debug("userObj: {}".format(userRow))
            if not userRow:
                raise NoUserException()
            # A second row means the lookup is ambiguous.
            if cur.fetchone():
                raise ManyUsersException()
        userId, pwhash, expiry = userRow

        claims = {}
        with conn.cursor() as cur:
            cur.execute('SELECT key, value FROM claims_for_user_v where "user" = %s and application = %s',
                        (userId, application))
            for claimKey, claimValue in cur:
                logger.debug("add claim {} -> {}".format(claimKey, claimValue))
                if claimKey not in claims:
                    claims[claimKey] = claimValue
                    continue
                # Repeated key: promote the stored value to a list once,
                # then keep appending.
                collected = claims[claimKey]
                if not isinstance(collected, list):
                    collected = [collected]
                    claims[claimKey] = collected
                collected.append(claimValue)

        return UserEntry(id=userId, login=login, pwhash=pwhash,
                         expiry=expiry, claims=claims)
    except psycopg2.Error as err:
        raise Exception("Error when connecting to database: {}".format(err))
    finally:
        if conn:
            conn.close()
def getRefreshTokenFromDB(application, login):
    """Create and store a new refresh-token record for a user.

    The user is looked up by login and application name. A 64-character
    random salt is then generated and inserted into ``token_t`` together
    with the user's id and expiry. Both statements run inside one
    ``with conn`` block.

    Returns:
        A ``RefreshTokenEntry`` holding the new token id, the salt, the
        login, the application name and the expiry.

    Raises:
        NoUserException: no matching user/application mapping exists.
        ManyUsersException: more than one matching row exists.
        NoTokenException: the INSERT returned no id.
        ManyTokensException: the INSERT returned more than one row.
        Exception: a ``psycopg2.Error`` occurred (message wraps the cause).
    """
    conn = None
    try:
        conn = psycopg2.connect(user = DB_USER, password = DB_PASS,
                                host = DB_HOST, database = DB_NAME)
        conn.autocommit = False

        with conn:
            with conn.cursor() as cur:
                cur.execute('SELECT u.id, u.expiry, a.name FROM user_t u, application_t a, user_application_mapping_t m' +
                            ' WHERE u.login = %s AND ' +
                            ' a.name = %s AND ' +
                            ' a.id = m.application AND u.id = m."user"',
                            (login, application))
                userObj = cur.fetchone()
                logger.debug("userObj: {}".format(userObj))
                if not userObj:
                    raise NoUserException()
                invObj = cur.fetchone()
                if invObj:
                    raise ManyUsersException()

            with conn.cursor() as cur:
                # The salt is embedded in the refresh token and checked when
                # the token is redeemed, so it must come from a
                # cryptographically secure source (not the `random` module).
                alphabet = string.ascii_letters + string.digits
                salt = ''.join(secrets.choice(alphabet) for _ in range(64))
                cur.execute('INSERT INTO token_t ("user", salt, expiry) VALUES (%s, %s, %s) RETURNING id',
                            (userObj[0], salt, userObj[1]))
                tokenObj = cur.fetchone()
                logger.debug("tokenObj: {}".format(tokenObj))
                if not tokenObj:
                    raise NoTokenException()
                invObj = cur.fetchone()
                if invObj:
                    raise ManyTokensException()

        refreshTokenEntry = RefreshTokenEntry(id=tokenObj[0], salt=salt, login=login,
                                              app=userObj[2], expiry=userObj[1])
        return refreshTokenEntry
    except psycopg2.Error as err:
        # Chain the original error so the root cause stays in the traceback.
        raise Exception("Error when connecting to database: {}".format(err)) from err
    finally:
        if conn:
            conn.close()
def getUserEntry(application, login, password):
    """Load a user entry and verify the supplied password against it.

    The stored hash is passed to ``crypt`` as the salt argument; the
    password is accepted when hashing it this way reproduces the stored
    hash.

    Returns:
        The ``UserEntry`` obtained from ``getUserEntryFromDB``.

    Raises:
        PasswordMismatchException: the password does not match.
        (plus whatever ``getUserEntryFromDB`` raises)
    """
    entry = getUserEntryFromDB(application, login)
    storedHash = entry.pwhash
    if crypt(password, storedHash) == storedHash:
        return entry
    raise PasswordMismatchException()
def generateToken(func, **args):
    """Extract credentials from the request body and hand them to ``func``.

    ``args["body"]`` must contain either the keys ``application``,
    ``login`` and ``password`` in clear, or the key ``encAleTuple`` holding
    a JWE that decrypts (with ``JWT_PRIV_KEY``) to a JSON object with those
    three keys.

    Args:
        func: callable invoked as ``func(application, login, password)``;
            its return value is returned unchanged.
        **args: keyword arguments of the request; only ``body`` is read.

    Raises:
        werkzeug.exceptions.Unauthorized: on every failure. The specific
            reason is written to the log only, never to the caller.
    """
    try:
        body = args["body"]

        if (("application" in body) and
                ("login" in body) and
                ("password" in body)):
            application = body["application"]
            login = body["login"]
            password = body["password"]
        elif "encAleTuple" in body:
            clearContent = jwe.decrypt(body["encAleTuple"], JWT_PRIV_KEY)
            clearObj = json.loads(clearContent)
            application = clearObj["application"]
            login = clearObj["login"]
            password = clearObj["password"]
        else:
            raise KeyError("Neither application, login and password nor encAleTuple given")

        # The password is deliberately kept out of the log.
        logger.debug(f"Tuple: {application} {login}")

        return func(application, login, password)
    except NoTokenException:
        logger.error("no token created")
        raise werkzeug.exceptions.Unauthorized()
    except ManyTokensException:
        logger.error("too many tokens created")
        raise werkzeug.exceptions.Unauthorized()
    except NoUserException:
        logger.error("no user found, login or application wrong")
        raise werkzeug.exceptions.Unauthorized()
    except ManyUsersException:
        logger.error("too many users found")
        raise werkzeug.exceptions.Unauthorized()
    except PasswordMismatchException:
        logger.error("wrong password")
        raise werkzeug.exceptions.Unauthorized()
    except KeyError:
        logger.error("application, login or password missing")
        raise werkzeug.exceptions.Unauthorized()
    except Exception as e:
        logger.error("unspecific exception: {}".format(str(e)))
        raise werkzeug.exceptions.Unauthorized()
def _makeSimpleToken(application, login, password, refresh=False):
    """Build and sign an RS256 auth token for a user.

    With ``refresh`` false the password is verified via ``getUserEntry``;
    with ``refresh`` true the user is loaded without a password check.
    The payload carries ``iss``, ``iat``, ``exp`` (issue time plus the
    user's expiry), ``sub`` (user id as string) and ``aud`` (application),
    followed by every claim of the user.

    NOTE(review): a user claim whose key equals one of the standard fields
    overwrites that field, because claims are written last.

    Returns:
        The encoded JWT.
    """
    if refresh:
        userEntry = getUserEntryFromDB(application, login)
    else:
        userEntry = getUserEntry(application, login, password)

    issuedAt = int(time.time())
    payload = dict(
        iss=JWT_ISSUER,
        iat=int(issuedAt),
        exp=int(issuedAt + userEntry.expiry),
        sub=str(userEntry.id),
        aud=application,
    )

    logger.debug("claims: {}".format(userEntry.claims))
    for claim in userEntry.claims.items():
        logger.debug("add claim {}".format(claim))
        claimKey, claimValue = claim
        payload[claimKey] = claimValue

    return jwt.encode(payload, JWT_PRIV_KEY, algorithm='RS256')
def _makeRefreshToken(application, login, password):
    """Create a refresh-token record and return it as a signed RS256 JWT.

    ``password`` is accepted for signature compatibility with the other
    token factories but is not used. Besides ``iss``, ``iat`` and ``exp``,
    the payload carries the login as ``sub``, the application name as
    ``xap``, the token record id as ``xid`` and the salt as ``xal``.

    Returns:
        The encoded JWT.
    """
    entry = getRefreshTokenFromDB(application, login)
    issuedAt = int(time.time())
    payload = dict(
        iss=JWT_ISSUER,
        iat=int(issuedAt),
        exp=int(issuedAt + entry.expiry),
        sub=str(entry.login),
        xap=str(entry.app),
        xid=str(entry.id),
        xal=str(entry.salt),
    )
    return jwt.encode(payload, JWT_PRIV_KEY, algorithm='RS256')
def _makeRefreshableTokens(application, login, password):
    """Return a fresh auth token together with a fresh refresh token.

    The auth token is created first (this is where the password is
    checked), then the refresh token.

    Returns:
        A dict with the keys ``authToken`` and ``refreshToken``.
    """
    # Dict displays evaluate their values in order, so the auth token is
    # still created before the refresh token.
    return {
        "authToken": _makeSimpleToken(application, login, password),
        "refreshToken": _makeRefreshToken(application, login, password),
    }
def generateSimpleToken(**args):
    """Issue a plain auth token for the credentials in ``args["body"]``.

    Delegates to ``generateToken`` with ``_makeSimpleToken`` as the factory.
    """
    tokenFactory = _makeSimpleToken
    return generateToken(tokenFactory, **args)
def generateRefreshableTokens(**args):
    """Issue an auth/refresh token pair for the credentials in ``args["body"]``.

    Delegates to ``generateToken`` with ``_makeRefreshableTokens`` as the
    factory.
    """
    tokenFactory = _makeRefreshableTokens
    return generateToken(tokenFactory, **args)
def getPubKey() -> str:
    """Return the public key that was loaded when the module was imported."""
    publicKey = JWT_PUB_KEY
    return publicKey
def decodeToken(token):
    """Verify a token against the public key and return its claims.

    Verification requires the audience ``"test"`` and accepts RS256 only,
    which is the algorithm this service signs with. Without an explicit
    allow-list the algorithm named in the token header would be trusted.

    Args:
        token: the encoded JWT.

    Returns:
        The decoded claims.

    Raises:
        werkzeug.exceptions.Unauthorized: the token failed verification.
    """
    try:
        return jwt.decode(token, JWT_PUB_KEY, algorithms=["RS256"], audience="test")
    except JWTError as e:
        logger.error("{}".format(e))
        raise werkzeug.exceptions.Unauthorized()
def testToken(user, token_info):
return {
"message": f"You are user_id {user} and the provided token has been signed by this issuers. Fine.",
"details": token_info
}
def checkAndInvalidateRefreshToken(login, xid, xal):
    """Confirm that a refresh token is known and count its use.

    Exactly one row must exist in ``token_t`` with id ``xid`` and salt
    ``xal`` that belongs to the user with the given login. The ``used``
    counter of that row is then incremented. Both statements run inside
    one ``with conn`` block.

    NOTE(review): the SELECT does not filter on ``used``; whether an
    already-used token gets rejected is not visible here. Confirm against
    the schema.

    Raises:
        NoTokenException: no matching token row exists.
        ManyTokensException: more than one matching token row exists.
        Exception: a ``psycopg2.Error`` occurred (message wraps the cause).
    """
    conn = None
    try:
        conn = psycopg2.connect(user=DB_USER, password=DB_PASS,
                                host=DB_HOST, database=DB_NAME)
        conn.autocommit = False

        with conn:
            with conn.cursor() as cur:
                cur.execute('SELECT t.id FROM token_t t, user_t u'
                            ' WHERE t.id = %s AND '
                            ' t.salt = %s AND '
                            ' t."user" = u.id AND '
                            ' u.login = %s',
                            (xid, xal, login))
                tokenRow = cur.fetchone()
                logger.debug("tokenObj: {}".format(tokenRow))
                if not tokenRow:
                    raise NoTokenException()
                # A second row means the token is not unique.
                if cur.fetchone():
                    raise ManyTokensException()

            with conn.cursor() as cur:
                cur.execute('UPDATE token_t SET used = used + 1 WHERE id = %s',
                            [xid])
    except psycopg2.Error as err:
        raise Exception("Error when connecting to database: {}".format(err))
    finally:
        if conn:
            conn.close()
def refreshTokens(**args):
    """Exchange a refresh token for a new auth/refresh token pair.

    ``args["body"]`` is the encoded refresh token. It is verified against
    the public key (RS256 only, the algorithm this service signs with),
    checked for expiry, matched against the token table via
    ``checkAndInvalidateRefreshToken``, and then a new pair is issued for
    the login (``sub``) and application (``xap``) named in the token. No
    password check is performed for the new auth token.

    Returns:
        A dict with the keys ``authToken`` and ``refreshToken``.

    Raises:
        werkzeug.exceptions.Unauthorized: on every failure. The specific
            reason is written to the log only.
    """
    try:
        refreshToken = args["body"]
        # Pin the accepted algorithm instead of trusting the token header.
        refreshTokenObj = jwt.decode(refreshToken, JWT_PUB_KEY, algorithms=["RS256"])
        logger.info(str(refreshTokenObj))

        # Explicit expiry check on the decoded "exp" claim.
        if refreshTokenObj["exp"] < int(time.time()):
            raise RefreshTokenExpiredException()

        checkAndInvalidateRefreshToken(refreshTokenObj["sub"], refreshTokenObj["xid"], refreshTokenObj["xal"])

        authToken = _makeSimpleToken(refreshTokenObj["xap"], refreshTokenObj["sub"], "", refresh=True)
        refreshToken = _makeRefreshToken(refreshTokenObj["xap"], refreshTokenObj["sub"], "")
        return {
            "authToken": authToken,
            "refreshToken": refreshToken
        }
    except JWTError as e:
        logger.error("jwt.decode failed: {}".format(e))
        raise werkzeug.exceptions.Unauthorized()
    except RefreshTokenExpiredException:
        logger.error("refresh token expired")
        raise werkzeug.exceptions.Unauthorized()
    except NoTokenException:
        logger.error("no token created/found")
        raise werkzeug.exceptions.Unauthorized()
    except NoValidTokenException:
        logger.error("no valid token found")
        raise werkzeug.exceptions.Unauthorized()
    except ManyTokensException:
        logger.error("too many tokens created/found")
        raise werkzeug.exceptions.Unauthorized()
    except NoUserException:
        logger.error("no user found, login or application wrong")
        raise werkzeug.exceptions.Unauthorized()
    except ManyUsersException:
        logger.error("too many users found")
        raise werkzeug.exceptions.Unauthorized()
    except PasswordMismatchException:
        logger.error("wrong password")
        raise werkzeug.exceptions.Unauthorized()
    except KeyError:
        # Raised when "body" is absent or the token lacks a claim read above.
        logger.error("refresh token or required claim missing")
        raise werkzeug.exceptions.Unauthorized()
    except Exception as e:
        logger.error("unspecific exception: {}".format(str(e)))
        raise werkzeug.exceptions.Unauthorized()