Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: send client version to server #267

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions diracx-routers/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ auth = "diracx.routers.auth:router"
WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy"
SandboxAccessPolicy = "diracx.routers.job_manager.access_policies:SandboxAccessPolicy"

[project.entry-points."diracx.min_client_version"]
diracx = "diracx.routers:DIRACX_MIN_CLIENT_VERSION"

[tool.setuptools.packages.find]
where = ["src"]
Expand Down
59 changes: 56 additions & 3 deletions diracx-routers/src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,21 @@
import os
from collections.abc import AsyncGenerator
from functools import partial
from http import HTTPStatus
from importlib.metadata import EntryPoint, EntryPoints, entry_points
from logging import Formatter, StreamHandler
from typing import Any, Awaitable, Callable, Iterable, Sequence, TypeVar, cast

import dotenv
from cachetools import TTLCache
from fastapi import APIRouter, Depends, Request, status
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
from fastapi.dependencies.models import Dependant
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.routing import APIRoute
from packaging.version import parse
from pydantic import TypeAdapter
from starlette.middleware.base import BaseHTTPMiddleware

# from starlette.types import ASGIApp
from uvicorn.logging import AccessFormatter, DefaultFormatter
Expand Down Expand Up @@ -289,15 +293,14 @@ def create_app_inner(
origins = [
"http://localhost:8000",
]

app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

app.add_middleware(ClientMinVersionCheckMiddleware)
configure_logger()
instrument_otel(app)

Expand Down Expand Up @@ -437,3 +440,53 @@ async def db_transaction(db: T2) -> AsyncGenerator[T2, None]:
if reason := await is_db_unavailable(db):
raise DBUnavailable(reason)
yield db


class ClientMinVersionCheckMiddleware(BaseHTTPMiddleware):
"""Custom FastAPI middleware to verify that the client has the minimum version required."""

def __init__(self, app: FastAPI):
super().__init__(app)
self.min_client_version = get_min_client_version()

async def dispatch(self, request: Request, call_next) -> Response:
client_version = request.headers.get("X-DIRACX-CLIENT-VERSION")
print(f">>>>> {client_version}")
if not client_version:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Client version header is missing.",
)

if self.is_version_too_old(client_version):
raise HTTPException(
status_code=HTTPStatus.UPGRADE_REQUIRED,
detail=f"Client version ({client_version}) not recent enough (>= {self.min_client_version}). Upgrade.",
)

response = await call_next(request)
return response

def is_version_too_old(self, client_version: str) -> bool:
"""Verify that client version is ge than min."""
return parse(client_version) < parse(self.min_client_version)


# I'm not sure if this has to be define here:
DIRACX_MIN_CLIENT_VERSION = "0.1.1"


def get_min_client_version():
"""Extracting min client version from entry_points and seraching for extension."""
matched_entry_points: EntryPoints = entry_points(group="diracx.min_client_version")
# Searching for an extension:
entry_points_dict: dict[str, EntryPoint] = {
ep.name: ep for ep in matched_entry_points
}
for ep_name, ep in entry_points_dict.items():
if ep_name != "diracx":
return ep.load()

# Taking diracx if no extension:
if "diracx" in entry_points_dict:
return entry_points_dict["diracx"].load()
15 changes: 15 additions & 0 deletions diracx-routers/tests/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from http import HTTPStatus

import pytest
from fastapi import HTTPException

pytestmark = pytest.mark.enabled_dependencies(
["ConfigSource", "AuthSettings", "OpenAccessPolicy"]
Expand Down Expand Up @@ -41,3 +44,15 @@ def test_unavailable_db(monkeypatch, test_client):
r = test_client.get("/api/job/123")
assert r.status_code == 503
assert r.json()


def test_min_client_version(test_client):
with pytest.raises(HTTPException) as response:
test_client.get("/", headers={"X-DIRACX-CLIENT-VERSION": "0.1.0"})
assert response.value.status_code == HTTPStatus.UPGRADE_REQUIRED
assert "not recent enough" in response.value.detail

with pytest.raises(HTTPException) as response:
test_client.get("/", headers={})
assert response.value.status_code == HTTPStatus.BAD_REQUEST
assert "header is missing" in response.value.detail
Loading