8 Commits

2 changed files with 68 additions and 42 deletions

View File

@ -1,6 +1,6 @@
#!/usr/bin/python #!/usr/bin/python
import mariadb import psycopg2
from pbkdf2 import crypt from pbkdf2 import crypt
import argparse import argparse
import os import os
@ -23,30 +23,26 @@ password = args.password
application = args.application application = args.application
DB_USER = os.environ["DB_USER"] DB_NAME = "authservice"
DB_PASS = os.environ["DB_PASS"]
DB_HOST = os.environ["DB_HOST"]
DB_NAME = os.environ["DB_NAME"]
pwhash = crypt(password, iterations=100000) pwhash = crypt(password, iterations=100000)
conn = None conn = None
cur = None cur = None
try: try:
conn = mariadb.connect(user = DB_USER, password = DB_PASS, conn = psycopg2.connect(database = DB_NAME)
host = DB_HOST, database = DB_NAME)
conn.autocommit = False conn.autocommit = False
cur = conn.cursor() cur = conn.cursor()
cur.execute(""" cur.execute("""
INSERT INTO users (login, pwhash) INSERT INTO user_t (login, pwhash)
VALUES(?, ?) VALUES(%s, %s)
""", [user, pwhash]) """, [user, pwhash])
cur.execute(""" cur.execute("""
INSERT INTO user_applications_mapping (application, user) INSERT INTO user_application_mapping_t (application,"user")
VALUES( VALUES(
(SELECT id FROM applications WHERE name = ?), (SELECT id FROM application_t WHERE name = %s),
(SELECT id FROM users WHERE login = ?) (SELECT id FROM user_t WHERE login = %s)
) )
""", [application, user]) """, [application, user])
conn.commit() conn.commit()

90
auth.py
View File

@ -12,6 +12,7 @@ from loguru import logger
import configparser import configparser
import random import random
import string import string
from flask import request
DB_USER = "" DB_USER = ""
@ -38,6 +39,9 @@ except KeyError:
class NoUserException(Exception): class NoUserException(Exception):
pass pass
class RefreshTokenExpiredException(Exception):
pass
class NoTokenException(Exception): class NoTokenException(Exception):
pass pass
@ -120,7 +124,7 @@ def getUserEntryFromDB(application: str, login: str):
if conn: if conn:
conn.close() conn.close()
def getRefreshTokenFromDB(application, login): def getRefreshTokenFromDB(application, login, httpClientIp):
conn = None conn = None
cur = None cur = None
try: try:
@ -145,8 +149,8 @@ def getRefreshTokenFromDB(application, login):
with conn.cursor() as cur: with conn.cursor() as cur:
salt = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(64)) salt = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(64))
cur.execute('INSERT INTO token_t ("user", salt) VALUES (%s, %s) RETURNING id', cur.execute('INSERT INTO token_t ("user", salt, expiry, client_ip) VALUES (%s, %s, %s, %s) RETURNING id',
(userObj[0], salt)) (userObj[0], salt, userObj[1], httpClientIp))
tokenObj = cur.fetchone() tokenObj = cur.fetchone()
logger.debug("tokenObj: {}".format(tokenObj)) logger.debug("tokenObj: {}".format(tokenObj))
if not tokenObj: if not tokenObj:
@ -192,9 +196,14 @@ def generateToken(func, **args):
else: else:
raise KeyError("Neither application, login and password nor encAleTuple given") raise KeyError("Neither application, login and password nor encAleTuple given")
logger.debug(f"Tuple: {application} {login} {password}") if request.headers.getlist("X-Forwarded-For"):
httpClientIp = request.headers.getlist("X-Forwarded-For")[0]
else:
httpClientIp = request.remote_addr
return func(application, login, password) logger.debug(f"Tuple: {application} {login} {password} {httpClientIp}")
return func(application, login, password, httpClientIp)
except NoTokenException: except NoTokenException:
logger.error("no token created") logger.error("no token created")
raise werkzeug.exceptions.Unauthorized() raise werkzeug.exceptions.Unauthorized()
@ -218,7 +227,7 @@ def generateToken(func, **args):
raise werkzeug.exceptions.Unauthorized() raise werkzeug.exceptions.Unauthorized()
def _makeSimpleToken(application, login, password, refresh=False): def _makeSimpleToken(application, login, password, httpClientIp, refresh=False):
userEntry = getUserEntry(application, login, password) if not refresh else getUserEntryFromDB(application, login) userEntry = getUserEntry(application, login, password) if not refresh else getUserEntryFromDB(application, login)
timestamp = int(time.time()) timestamp = int(time.time())
@ -236,8 +245,8 @@ def _makeSimpleToken(application, login, password, refresh=False):
return jwt.encode(payload, JWT_PRIV_KEY, algorithm='RS256') return jwt.encode(payload, JWT_PRIV_KEY, algorithm='RS256')
def _makeRefreshToken(application, login, password): def _makeRefreshToken(application, login, password, httpClientIp):
refreshTokenEntry = getRefreshTokenFromDB(application, login) refreshTokenEntry = getRefreshTokenFromDB(application, login, httpClientIp)
timestamp = int(time.time()) timestamp = int(time.time())
payload = { payload = {
@ -252,17 +261,17 @@ def _makeRefreshToken(application, login, password):
return jwt.encode(payload, JWT_PRIV_KEY, algorithm='RS256') return jwt.encode(payload, JWT_PRIV_KEY, algorithm='RS256')
def _makeRefreshableTokens(application, login, password): def _makeRefreshableTokens(application, login, password, httpClientIp):
authToken = _makeSimpleToken(application, login, password) authToken = _makeSimpleToken(application, login, password, httpClientIp)
refreshToken = _makeRefreshToken(application, login, password) refreshToken = _makeRefreshToken(application, login, password, httpClientIp)
return { return {
"authToken": authToken, "authToken": authToken,
"refreshToken": refreshToken "refreshToken": refreshToken
} }
def generateSimpleToken(**args): def generateSimpleToken(**args):
return generateToken(_makeSimpleToken, **args)
return generateToken(_makeSimpleToken, **args)
def generateRefreshableTokens(**args): def generateRefreshableTokens(**args):
return generateToken(_makeRefreshableTokens, **args) return generateToken(_makeRefreshableTokens, **args)
@ -283,34 +292,44 @@ def testToken(user, token_info):
} }
def checkAndInvalidateRefreshToken(login, xid, xal): def checkAndInvalidateRefreshToken(login, xid, xal, httpClientIp):
conn = None
cur = None
try: try:
validTokenFound = False
conn = psycopg2.connect(user = DB_USER, password = DB_PASS, conn = psycopg2.connect(user = DB_USER, password = DB_PASS,
host = DB_HOST, database = DB_NAME) host = DB_HOST, database = DB_NAME)
conn.autocommit = False conn.autocommit = False
with conn: with conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute('SELECT t.id FROM token_t t, user_t u' + cur.execute('SELECT t.id, t.client_ip FROM token_t t, user_t u' +
' WHERE t.id = %s AND ' + ' WHERE t.valid = true AND ' +
' t.id = %s AND ' +
' t.salt = %s AND ' + ' t.salt = %s AND ' +
' t."user" = u.id AND ' + ' t."user" = u.id AND ' +
' u.login = %s AND ' + ' u.login = %s',
' t.valid = true',
(xid, xal, login)) (xid, xal, login))
tokenObj = cur.fetchone() tokenObj = cur.fetchone()
logger.debug("tokenObj: {}".format(tokenObj)) logger.debug("tokenObj: {}".format(tokenObj))
if not tokenObj: if not tokenObj:
raise NoValidTokenException() raise NoTokenException()
invObj = cur.fetchone() invObj = cur.fetchone()
if invObj: if invObj:
raise ManyTokensException() raise ManyTokensException()
with conn.cursor() as cur: if (tokenObj[1] == httpClientIp):
cur.execute('UPDATE token_t SET valid = false WHERE id = %s', with conn.cursor() as cur:
[ xid ]) cur.execute('UPDATE token_t SET used = used + 1 WHERE id = %s',
[ xid ])
validTokenFound = True
else:
logger.warning(f"Client IP in token {tokenObj[1]} and current one {httpClientIp} does not match")
with conn.cursor() as cur:
cur.execute('UPDATE token_t SET valid = false WHERE id = %s',
[ xid ])
if (not validTokenFound):
raise NoValidTokenException()
except psycopg2.Error as err: except psycopg2.Error as err:
raise Exception("Error when connecting to database: {}".format(err)) raise Exception("Error when connecting to database: {}".format(err))
finally: finally:
@ -324,10 +343,18 @@ def refreshTokens(**args):
refreshTokenObj = jwt.decode(refreshToken, JWT_PUB_KEY) refreshTokenObj = jwt.decode(refreshToken, JWT_PUB_KEY)
logger.info(str(refreshTokenObj)) logger.info(str(refreshTokenObj))
checkAndInvalidateRefreshToken(refreshTokenObj["sub"], refreshTokenObj["xid"], refreshTokenObj["xal"]) if request.headers.getlist("X-Forwarded-For"):
httpClientIp = request.headers.getlist("X-Forwarded-For")[0]
else:
httpClientIp = request.remote_addr
authToken = _makeSimpleToken(refreshTokenObj["xap"], refreshTokenObj["sub"], "", refresh=True) if refreshTokenObj["exp"] < int(time.time()):
refreshToken = _makeRefreshToken(refreshTokenObj["xap"], refreshTokenObj["sub"], "") raise RefreshTokenExpiredException()
checkAndInvalidateRefreshToken(refreshTokenObj["sub"], refreshTokenObj["xid"], refreshTokenObj["xal"], httpClientIp)
authToken = _makeSimpleToken(refreshTokenObj["xap"], refreshTokenObj["sub"], "", httpClientIp, refresh=True)
refreshToken = _makeRefreshToken(refreshTokenObj["xap"], refreshTokenObj["sub"], "", httpClientIp)
return { return {
"authToken": authToken, "authToken": authToken,
"refreshToken": refreshToken "refreshToken": refreshToken
@ -335,8 +362,11 @@ def refreshTokens(**args):
except JWTError as e: except JWTError as e:
logger.error("jwt.decode failed: {}".format(e)) logger.error("jwt.decode failed: {}".format(e))
raise werkzeug.exceptions.Unauthorized() raise werkzeug.exceptions.Unauthorized()
except RefreshTokenExpiredException:
logger.error("refresh token expired")
raise werkzeug.exceptions.Unauthorized()
except NoTokenException: except NoTokenException:
logger.error("no token created") logger.error("no token created/found")
raise werkzeug.exceptions.Unauthorized() raise werkzeug.exceptions.Unauthorized()
except NoValidTokenException: except NoValidTokenException:
logger.error("no valid token found") logger.error("no valid token found")
@ -356,8 +386,8 @@ def refreshTokens(**args):
except KeyError: except KeyError:
logger.error("application, login or password missing") logger.error("application, login or password missing")
raise werkzeug.exceptions.Unauthorized() raise werkzeug.exceptions.Unauthorized()
#except Exception as e: except Exception as e:
# logger.error("unspecific exception: {}".format(str(e))) logger.error("unspecific exception: {}".format(str(e)))
# raise werkzeug.exceptions.Unauthorized() raise werkzeug.exceptions.Unauthorized()