From 317a6c720d19a1bee23ccc096b594242186b4a04 Mon Sep 17 00:00:00 2001 From: Dominique Date: Wed, 29 Nov 2023 13:11:12 +0100 Subject: [PATCH] Dom debug (#130) * Fixed #127 --- ptmd/api/queries/users.py | 21 +++++++--- ptmd/database/const.py | 1 + ptmd/database/models/user.py | 7 ++++ ptmd/database/queries/users.py | 5 ++- ptmd/exceptions.py | 41 +++++++++++++++++++ tests/test_api/test_queries/test_users.py | 37 ++++++++++++++++- tests/test_database/test_models/test_user.py | 18 ++++++-- .../test_database/test_queries/test_users.py | 20 ++++++++- tests/test_exceptions.py | 11 +++++ 9 files changed, 147 insertions(+), 14 deletions(-) create mode 100644 ptmd/exceptions.py create mode 100644 tests/test_exceptions.py diff --git a/ptmd/api/queries/users.py b/ptmd/api/queries/users.py index 5f86e2f..473ab2d 100644 --- a/ptmd/api/queries/users.py +++ b/ptmd/api/queries/users.py @@ -14,6 +14,7 @@ from ptmd.config import session from ptmd.const import CREATE_USER_SCHEMA_PATH from ptmd.database import login_user, get_token, User, TokenBlocklist, Token, Organisation +from ptmd.exceptions import PasswordPolicyError, TokenInvalidError, TokenExpiredError from .utils import check_role @@ -76,11 +77,19 @@ def change_password() -> tuple[Response, int]: if new_password != repeat_password: return jsonify({"msg": "Passwords do not match"}), 400 + if password == new_password: + return jsonify({"msg": "New password cannot be the same as the old one"}), 400 + user: User = User.query.filter(User.id == get_jwt()['sub']).first() - changed: bool = user.change_password(old_password=password, new_password=new_password) - if not changed: - return jsonify({"msg": "Wrong password"}), 400 - return jsonify({"msg": "Password changed successfully"}), 200 + try: + changed: bool = user.change_password(old_password=password, new_password=new_password) + if not changed: + return jsonify({"msg": "Wrong password"}), 400 + return jsonify({"msg": "Password changed successfully"}), 200 if changed else jsonify() + except PasswordPolicyError as e: + return jsonify({"msg": str(e)}), 400 + except Exception: + return jsonify({"msg": "An unexpected error occurred"}), 500 @check_role(role='disabled') @@ -188,8 +197,10 @@ def reset_password(token: str) -> tuple[Response, int]: user.set_password(password) session.delete(reset_token_from_db) # type: ignore return jsonify({"msg": "Password changed successfully"}), 200 - except Exception as e: + except (PasswordPolicyError, TokenInvalidError, TokenExpiredError) as e: return jsonify({"msg": str(e)}), 400 + except Exception: + return jsonify({"msg": "An unexpected error occurred"}), 500 @check_role(role='admin') diff --git a/ptmd/database/const.py b/ptmd/database/const.py index 0fdae3b..b378ff9 100644 --- a/ptmd/database/const.py +++ b/ptmd/database/const.py @@ -7,3 +7,4 @@ SQLALCHEMY_DATABASE_URI: str = DOT_ENV_CONFIG['SQLALCHEMY_DATABASE_URL'] SQLALCHEMY_SECRET_KEY: str = DOT_ENV_CONFIG['SQLALCHEMY_SECRET_KEY'] +PASSWORD_POLICY: str = "^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9])(?=.*?[#?!@$%^&*-]).{8,20}$" diff --git a/ptmd/database/models/user.py b/ptmd/database/models/user.py index 7cb14cd..e44c400 100644 --- a/ptmd/database/models/user.py +++ b/ptmd/database/models/user.py @@ -2,10 +2,13 @@ """ from __future__ import annotations from typing import Generator +from re import match from passlib.hash import bcrypt from ptmd.config import Base, db, session from ptmd.const import ROLES +from ptmd.exceptions import PasswordPolicyError +from ptmd.database.const import PASSWORD_POLICY from ptmd.database.models.token import Token from ptmd.lib.email import send_validation_mail, send_validated_account_mail @@ -97,7 +100,11 @@ def set_password(self, password: str) -> None: """ Set the user password. Helper function to avoid code repetition. :param password: the new password + + :raises PasswordPolicyError: if the password does not match the password policy """ + if not match(PASSWORD_POLICY, password): + raise PasswordPolicyError() self.password = bcrypt.hash(password) session.commit() diff --git a/ptmd/database/queries/users.py b/ptmd/database/queries/users.py index f4e9139..cc201b0 100644 --- a/ptmd/database/queries/users.py +++ b/ptmd/database/queries/users.py @@ -7,6 +7,7 @@ from ptmd.config import session from ptmd.logger import LOGGER +from ptmd.exceptions import TokenExpiredError, TokenInvalidError from ptmd.database.models import User, Token @@ -52,7 +53,7 @@ def get_token(token: str) -> Token: """ token_from_db: Token = Token.query.filter(Token.token == token).first() if token_from_db is None: - raise Exception("Invalid token") + raise TokenInvalidError if token_from_db.expires_on < datetime.now(token_from_db.expires_on.tzinfo): - raise Exception("Token expired") + raise TokenExpiredError return token_from_db diff --git a/ptmd/exceptions.py b/ptmd/exceptions.py new file mode 100644 index 0000000..413225c --- /dev/null +++ b/ptmd/exceptions.py @@ -0,0 +1,41 @@ +""" Custom exceptions for the ptmd package """ +from __future__ import annotations +from abc import ABC + + +class APIError(Exception, ABC): + """ Exception raised when an API error occurs. This is an abstract class, do not use directly. """ + + def __init__(self) -> None: + """ Constructor, do not use """ + self.message: str | None = None + raise SyntaxError("Cannot instantiate abstract class APIError") + + def __str__(self) -> str: + """ String representation of the exception """ + return self.message or "" + + +class PasswordPolicyError(APIError): + """ Exception raised when a password does not meet the password policy """ + + def __init__(self) -> None: + """ Constructor """ + self.message: str = "Password must be between 8 and 20 characters long, contain at least one uppercase " \ + "letter, one lowercase letter, one number and one special character." + + +class TokenExpiredError(APIError): + """ Exception raised when a token is expired """ + + def __init__(self) -> None: + """ Constructor """ + self.message: str = "Token expired" + + +class TokenInvalidError(APIError): + """ Exception raised when a token is invalid """ + + def __init__(self) -> None: + """ Constructor """ + self.message: str = "Invalid token" diff --git a/tests/test_api/test_queries/test_users.py b/tests/test_api/test_queries/test_users.py index ce02eb6..9f223b4 100644 --- a/tests/test_api/test_queries/test_users.py +++ b/tests/test_api/test_queries/test_users.py @@ -6,6 +6,7 @@ from sqlalchemy.exc import IntegrityError from ptmd.api import app +from ptmd.exceptions import PasswordPolicyError HEADERS = {'Content-Type': 'application/json'} @@ -110,6 +111,13 @@ def test_change_pwd(self, mock_user, mock_jwt, mock_session, self.assertEqual(created_user.json, {'msg': 'Passwords do not match'}) user_data['confirm_password'] = '1234' + created_user = client.put('/api/users', + headers={'Authorization': f'Bearer {123}', **HEADERS}, + data=dumps(user_data)) + self.assertEqual(created_user.json, {'msg': 'New password cannot be the same as the old one'}) + + user_data['confirm_password'] = '666' + user_data['new_password'] = '666' mock_user.query.filter().first().change_password.return_value = False created_user = client.put('/api/users', headers={'Authorization': f'Bearer {123}', **HEADERS}, @@ -128,6 +136,23 @@ def test_change_pwd(self, mock_user, mock_jwt, mock_session, data=dumps(user_data)) self.assertEqual(created_user.json, {'msg': 'You are not authorized to access this route'}) + mock_user.query.filter().first().change_password.side_effect = PasswordPolicyError() + mock_get_current_user().role = 'admin' + created_user = client.put('/api/users', + headers={'Authorization': f'Bearer {123}', **HEADERS}, + data=dumps(user_data)) + self.assertEqual(created_user.json, {'msg': "Password must be between 8 and 20 characters long, contain at " + "least one uppercase letter, one lowercase letter, one number " + "and one special character."}) + self.assertEqual(created_user.status_code, 400) + + mock_user.query.filter().first().change_password = lambda x: x/0 + created_user = client.put('/api/users', + headers={'Authorization': f'Bearer {123}', **HEADERS}, + data=dumps(user_data)) + self.assertEqual(created_user.json, {'msg': 'An unexpected error occurred'}) + self.assertEqual(created_user.status_code, 500) + @patch('ptmd.api.queries.users.User') @patch('ptmd.api.queries.users.get_jwt', return_value={'sub': 1}) def test_get_me(self, mock_jwt, mock_user, mock_get_current_user, mock_verify_jwt, mock_verify_in_request): @@ -258,13 +283,21 @@ def test_reset_password_failed(self, mock_get_current_user, mock_verify_jwt, moc @patch('ptmd.api.queries.users.get_token') def test_reset_password_error(self, mock_token, mock_get_current_user, mock_verify_jwt, mock_verify_in_request): - mock_token.side_effect = Exception('test') + mock_token.side_effect = PasswordPolicyError() headers = {'Authorization': f'Bearer {123}', **HEADERS} with app.test_client() as client: response = client.post('/api/users/reset/123', data=dumps({"password": "None"}), headers=headers) - self.assertEqual(response.json, {"msg": "test"}) + self.assertEqual(response.json, {"msg": "Password must be between 8 and 20 characters long, contain at " + "least one uppercase letter, one lowercase letter, one number " + "and one special character."}) self.assertEqual(response.status_code, 400) + mock_token.side_effect = Exception() + with app.test_client() as client: + response = client.post('/api/users/reset/123', data=dumps({"password": "None"}), headers=headers) + self.assertEqual(response.json, {"msg": "An unexpected error occurred"}) + self.assertEqual(response.status_code, 500) + @patch('ptmd.api.queries.users.get_token') @patch('ptmd.api.queries.users.session') def test_reset_password_success(self, mock_session, mock_token, diff --git a/tests/test_database/test_models/test_user.py b/tests/test_database/test_models/test_user.py index da1b9bc..27a9715 100644 --- a/tests/test_database/test_models/test_user.py +++ b/tests/test_database/test_models/test_user.py @@ -2,6 +2,7 @@ from unittest.mock import patch, mock_open from ptmd.database import User, Organisation, File +from ptmd.exceptions import PasswordPolicyError @patch("builtins.open", mock_open(read_data="{'save_credentials_file': 'test'}")) @@ -10,14 +11,14 @@ class TestUser(TestCase): @patch('ptmd.database.models.token.send_confirmation_mail', return_value=True) def test_user(self, mock_send_confirmation_mail): expected_user = {'files': [], 'id': None, 'organisation': None, 'username': 'test', 'role': 'disabled'} - user = User(username='test', password='test', email='your@email.com') + user = User(username='test', password='A!Str0ngPwd', email='your@email.com') self.assertEqual(dict(user), expected_user) - self.assertTrue(user.validate_password('test')) + self.assertTrue(user.validate_password('A!Str0ngPwd')) with patch('ptmd.database.models.user.session') as mock_session: - changed = user.change_password(old_password='test', new_password='test2') + changed = user.change_password(old_password='A!Str0ngPwd', new_password='A!Str0ngPwd2') self.assertTrue(changed) - changed = user.change_password(old_password='test', new_password='test2') + changed = user.change_password(old_password='test', new_password='A!Str0ngPwd') self.assertFalse(changed) with patch('ptmd.database.models.user.send_validation_mail') as mock_email: @@ -94,3 +95,12 @@ def test_user_serialisation_with_organisation(self, mock_organisation, mock_orga files = dict(user)['files'] self.assertIn(dict(file_1), files) self.assertIn(dict(file_2), files) + + @patch('ptmd.database.models.user.session') + def test_set_password_policy_failure(self, mock_session): + user = User(username='test', password='test', email='your@email.com', role='admin') + with self.assertRaises(PasswordPolicyError) as context: + user.set_password('test') + self.assertEqual(str(context.exception), + "Password must be between 8 and 20 characters long, contain at least one uppercase letter, one " + "lowercase letter, one number and one special character.") diff --git a/tests/test_database/test_queries/test_users.py b/tests/test_database/test_queries/test_users.py index 7fb50a4..feebce1 100644 --- a/tests/test_database/test_queries/test_users.py +++ b/tests/test_database/test_queries/test_users.py @@ -1,7 +1,9 @@ from unittest import TestCase from unittest.mock import patch +from datetime import datetime, timedelta -from ptmd.database.queries import login_user, create_organisations, create_users +from ptmd.database.queries import login_user, create_organisations, create_users, get_token +from ptmd.exceptions import TokenInvalidError, TokenExpiredError INPUTS_ORGS = {'KIT': {"g_drive": "123", "long_name": "test12"}} @@ -47,3 +49,19 @@ def test_create_users(self, mock_user, mock_organisation, mock_users_session, mo input_users = [{'username': 'test', 'password': 'test', 'organisation': organisations['KIT']}] user = create_users(users=input_users) self.assertEqual(user[0], 123) + + @patch('ptmd.database.queries.users.Token') + def test_get_token(self, mock_token): + mock_token.query.filter().first.return_value = None + with self.assertRaises(TokenInvalidError) as context: + get_token('ABC') + self.assertEqual(str(context.exception), 'Invalid token') + + class MockToken: + def __init__(self): + self.expires_on = datetime.now() - timedelta(days=10) + + mock_token.query.filter().first.return_value = MockToken() + with self.assertRaises(TokenExpiredError) as context: + get_token('ABC') + self.assertEqual(str(context.exception), 'Token expired') diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..958f1a4 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,11 @@ +from unittest import TestCase + +from ptmd.exceptions import APIError + + +class TestExceptions(TestCase): + + def test_api_error(self): + with self.assertRaises(SyntaxError) as context: + APIError() + self.assertEqual(str(context.exception), 'Cannot instantiate abstract class APIError')