Skip to content

Commit

Permalink
Merge pull request #85 from 8thgencore/main
Browse files Browse the repository at this point in the history
Migrate from passlib to bcrypt
  • Loading branch information
jonra1993 committed Sep 9, 2023
2 parents c07be9a + 8ab49fc commit 863acdc
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 71 deletions.
41 changes: 25 additions & 16 deletions backend/app/app/api/deps.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from collections.abc import AsyncGenerator
from typing import Callable

import redis.asyncio as aioredis
from fastapi import Depends, HTTPException, status
from app.utils.token import get_valid_tokens
from app.utils.minio_client import MinioClient
from fastapi.security import OAuth2PasswordBearer
from jose import jwt
from app.models.user_model import User
from pydantic import ValidationError
from jwt import DecodeError, ExpiredSignatureError, MissingRequiredClaimError
from redis.asyncio import Redis
from sqlmodel.ext.asyncio.session import AsyncSession

from app import crud
from app.core import security
from app.core.config import settings
from app.core.security import decode_token
from app.db.session import SessionLocal, SessionLocalCelery
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.user_model import User
from app.schemas.common_schema import IMetaGeneral, TokenType
import redis.asyncio as aioredis
from redis.asyncio import Redis

from app.utils.minio_client import MinioClient
from app.utils.token import get_valid_tokens

reusable_oauth2 = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
Expand Down Expand Up @@ -49,23 +49,32 @@ async def get_general_meta() -> IMetaGeneral:

def get_current_user(required_roles: list[str] = None) -> Callable[[], User]:
async def current_user(
token: str = Depends(reusable_oauth2),
access_token: str = Depends(reusable_oauth2),
redis_client: Redis = Depends(get_redis_client),
) -> User:
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
payload = decode_token(access_token)
except ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Your token has expired. Please log in again.",
)
except (jwt.JWTError, ValidationError):
except DecodeError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
detail="Error when decoding the token. Please check your request.",
)
except MissingRequiredClaimError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="There is no required field in your token. Please contact the administrator.",
)

user_id = payload["sub"]
valid_access_tokens = await get_valid_tokens(
redis_client, user_id, TokenType.ACCESS
)
if valid_access_tokens and token not in valid_access_tokens:
if valid_access_tokens and access_token not in valid_access_tokens:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
Expand Down
45 changes: 27 additions & 18 deletions backend/app/app/api/v1/endpoints/login.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
from datetime import timedelta
from fastapi import APIRouter, Body, Depends, HTTPException
from redis.asyncio import Redis
from app.utils.token import get_valid_tokens
from app.utils.token import delete_tokens
from app.utils.token import add_token_to_redis
from app.core.security import get_password_hash
from app.core.security import verify_password
from app.models.user_model import User
from app.api.deps import get_redis_client

from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from jose import jwt
from jwt import DecodeError, ExpiredSignatureError, MissingRequiredClaimError
from pydantic import EmailStr
from pydantic import ValidationError
from redis.asyncio import Redis

from app import crud
from app.api import deps
from app.api.deps import get_redis_client
from app.core import security
from app.core.config import settings
from app.schemas.common_schema import TokenType, IMetaGeneral
from app.schemas.token_schema import TokenRead, Token, RefreshToken
from app.core.security import decode_token, get_password_hash, verify_password
from app.models.user_model import User
from app.schemas.common_schema import IMetaGeneral, TokenType
from app.schemas.response_schema import IPostResponseBase, create_response
from app.schemas.token_schema import RefreshToken, Token, TokenRead
from app.utils.token import add_token_to_redis, delete_tokens, get_valid_tokens

router = APIRouter()

Expand Down Expand Up @@ -147,11 +145,22 @@ async def get_new_access_token(
Gets a new access token using the refresh token for future requests
"""
try:
payload = jwt.decode(
body.refresh_token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
payload = decode_token(body.refresh_token)
except ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Your token has expired. Please log in again.",
)
except DecodeError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Error when decoding the token. Please check your request.",
)
except MissingRequiredClaimError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="There is no required field in your token. Please contact the administrator.",
)
except (jwt.JWTError, ValidationError):
raise HTTPException(status_code=403, detail="Refresh token invalid")

if payload["type"] == "refresh":
user_id = payload["sub"]
Expand All @@ -163,7 +172,7 @@ async def get_new_access_token(

access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
user = await crud.user.get(id=user_id)
if getattr(user, "is_active"):
if user.is_active:
access_token = security.create_access_token(
payload["sub"], expires_delta=access_token_expires
)
Expand Down
49 changes: 37 additions & 12 deletions backend/app/app/core/security.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from datetime import datetime, timedelta
from typing import Any

import bcrypt
import jwt
from cryptography.fernet import Fernet
from jose import jwt
from passlib.context import CryptContext

from app.core.config import settings

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
fernet = Fernet(str.encode(settings.ENCRYPT_KEY))

ALGORITHM = "HS256"
JWT_ALGORITHM = "HS256"


def create_access_token(subject: str | Any, expires_delta: timedelta = None) -> str:
Expand All @@ -19,8 +20,12 @@ def create_access_token(subject: str | Any, expires_delta: timedelta = None) ->
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = {"exp": expire, "sub": str(subject), "type": "access"}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

return jwt.encode(
payload=to_encode,
key=settings.ENCRYPT_KEY,
algorithm=JWT_ALGORITHM,
)


def create_refresh_token(subject: str | Any, expires_delta: timedelta = None) -> str:
Expand All @@ -31,16 +36,36 @@ def create_refresh_token(subject: str | Any, expires_delta: timedelta = None) ->
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
)
to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

return jwt.encode(
payload=to_encode,
key=settings.ENCRYPT_KEY,
algorithm=JWT_ALGORITHM,
)


def decode_token(token: str) -> dict[str, Any]:
return jwt.decode(
jwt=token,
key=settings.ENCRYPT_KEY,
algorithms=[JWT_ALGORITHM],
)


def verify_password(plain_password: str | bytes, hashed_password: str | bytes) -> bool:
if isinstance(plain_password, str):
plain_password = plain_password.encode()
if isinstance(hashed_password, str):
hashed_password = hashed_password.encode()

return bcrypt.checkpw(plain_password, hashed_password)

def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(plain_password: str | bytes) -> str:
if isinstance(plain_password, str):
plain_password = plain_password.encode()

def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
return bcrypt.hashpw(plain_password, bcrypt.gensalt()).decode()


def get_data_encrypt(data) -> str:
Expand Down
59 changes: 37 additions & 22 deletions backend/app/app/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import gc
import logging
from contextlib import asynccontextmanager
from typing import Any
from uuid import UUID, uuid4
from app import crud
from app.schemas.common_schema import IChatResponse, IUserMessage
from app.utils.uuid6 import uuid7

from fastapi import (
FastAPI,
HTTPException,
Expand All @@ -13,25 +12,27 @@
WebSocketDisconnect,
status,
)
from app.core import security
from app.api.deps import get_redis_client
from fastapi_pagination import add_pagination
from pydantic import ValidationError
from starlette.middleware.cors import CORSMiddleware
from app.api.v1.api import api_router as api_router_v1
from app.core.config import ModeEnum, settings
from fastapi_async_sqlalchemy import SQLAlchemyMiddleware, db
from fastapi_cache import FastAPICache
from fastapi_cache.backends.redis import RedisBackend
from fastapi_async_sqlalchemy import SQLAlchemyMiddleware, db
from contextlib import asynccontextmanager
from app.utils.fastapi_globals import g, GlobalsMiddleware
from transformers import pipeline
from fastapi_limiter import FastAPILimiter
from jose import jwt
from fastapi_limiter.depends import WebSocketRateLimiter
from fastapi_pagination import add_pagination
from jwt import DecodeError, ExpiredSignatureError, MissingRequiredClaimError
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from sqlalchemy.pool import NullPool, QueuePool
from starlette.middleware.cors import CORSMiddleware
from transformers import pipeline

from app import crud
from app.api.deps import get_redis_client
from app.api.v1.api import api_router as api_router_v1
from app.core.config import ModeEnum, settings
from app.core.security import decode_token
from app.schemas.common_schema import IChatResponse, IUserMessage
from app.utils.fastapi_globals import GlobalsMiddleware, g
from app.utils.uuid6 import uuid7


async def user_id_identifier(request: Request):
Expand All @@ -45,16 +46,25 @@ async def user_id_identifier(request: Request):
if len(header_parts) == 2 and header_parts[0].lower() == "bearer":
token = header_parts[1]
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
payload = decode_token(token)
except ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Your token has expired. Please log in again.",
)
except (jwt.JWTError, ValidationError):
except DecodeError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
detail="Error when decoding the token. Please check your request.",
)
except MissingRequiredClaimError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="There is no required field in your token. Please contact the administrator.",
)

user_id = payload["sub"]
print("here2", user_id)

return user_id

if request.scope["type"] == "websocket":
Expand All @@ -65,7 +75,7 @@ async def user_id_identifier(request: Request):
return forwarded.split(",")[0]

client = request.client
ip = getattr(client, "host", "0.0.0.0")
ip = getattr(client, "host", "0.0.0.0")
return ip + ":" + request.scope["path"]


Expand Down Expand Up @@ -134,7 +144,12 @@ class CustomException(Exception):
code: str
message: str

def __init__(self, http_code: int = 500, code: str | None = None, message: str = 'This is an error message'):
def __init__(
self,
http_code: int = 500,
code: str | None = None,
message: str = "This is an error message",
):
self.http_code = http_code
self.code = code if code else str(self.http_code)
self.message = message
Expand Down
6 changes: 3 additions & 3 deletions backend/app/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ alembic = "^1.10.2"
asyncpg = "^0.27.0"
fastapi = {extras = ["all"], version = "^0.95.2"}
sqlmodel = "^0.0.8"
python-jose = "^3.3.0"
cryptography = "^38.0.3"
passlib = "^1.7.4"
cryptography = "^41.0.3"
bcrypt = "^4.0.1"
pyjwt = { extras = ["crypto"], version = "^2.8.0" }
SQLAlchemy-Utils = "^0.38.3"
SQLAlchemy = "^1.4.40"
fastapi-pagination = {extras = ["sqlalchemy"], version = "^0.11.4"}
Expand Down

0 comments on commit 863acdc

Please sign in to comment.