diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..e54c5a6
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,4 @@
+[run]
+omit =
+    # Omit any tests
+    */tests*
diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..449e476
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,7 @@
+[flake8]
+extend-ignore = E203, E266, E501
+# line length is intentionally set to 80 here because black uses Bugbear
+# See https://github.com/psf/black/blob/master/docs/the_black_code_style.md#line-length for more details
+max-line-length = 80
+max-complexity = 18
+select = B,C,E,F,W,T4,B9
diff --git a/.github/workflows/docker-hub-image-build.yml b/.github/workflows/docker-hub-image-build.yml
new file mode 100644
index 0000000..1194a03
--- /dev/null
+++ b/.github/workflows/docker-hub-image-build.yml
@@ -0,0 +1,48 @@
+name: Build image for Docker Hub
+
+on:
+  release:
+    types:
+      - "released"
+
+jobs:
+  main:
+    runs-on: ubuntu-20.04
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v1
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v1
+
+      - name: Get the version
+        id: get_version
+        run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
+
+      - name: Login to DockerHub
+        uses: docker/login-action@v1
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Build and push
+        id: docker_build
+        uses: docker/build-push-action@v2
+        with:
+          context: .
+          file: ./Dockerfile.prod
+          platforms: linux/amd64
+          build-args: |
+            release_version=${{ steps.get_version.outputs.VERSION }}
+          push: true
+          cache-from: type=local,src=/tmp/.buildx-cache
+          cache-to: type=local,dest=/tmp/.buildx-cache
+          tags: |
+            onaio/duva:latest
+            onaio/duva:${{ steps.get_version.outputs.VERSION }}
+
+      - name: Image digest
+        run: echo ${{ steps.docker_build.outputs.digest }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f498d23
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+__pycache__/
+media/*
+.vscode/
+hyperd*.log
+.DS_Store
+*.db
+MANIFEST
+.tox/
+.pytest_cache/
+.coverage
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..eae34c4
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,9 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: 20.8b1
+    hooks:
+      - id: black
+  - repo: https://gitlab.com/pycqa/flake8
+    rev: 3.8.4
+    hooks:
+      - id: flake8
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..e5d6cf2
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,19 @@
+sudo: required
+dist: focal
+language: python
+jobs:
+  include:
+    - python: 3.7
+      env: TOXENV=py37
+    - python: 3.8
+      env: TOXENV=py38
+    - python: 3.7
+      env: TOXENV=lint
+services:
+  - redis-server
+install:
+  - pip install -U pip
+  - pip install tox
+script: tox
+notifications:
+  slack: onaio:snkNXgprD498qQv4DgRREKJF
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..fe1acdd
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,6 @@
+FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7
+
+RUN mkdir -p /root/.aws
+COPY . /app
+
+RUN mkdir -p /app/media && pip install --no-cache-dir -r /app/requirements.pip
diff --git a/Dockerfile.prod b/Dockerfile.prod
new file mode 100644
index 0000000..a1df38e
--- /dev/null
+++ b/Dockerfile.prod
@@ -0,0 +1,20 @@
+FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7
+ARG release_version=v0.0.1
+
+# Create application user
+RUN useradd -m duva
+
+# Create directory for AWS Configurations
+RUN mkdir -p /home/duva/.aws
+
+# Clone Duva application source code
+RUN git clone -b ${release_version} https://github.com/onaio/duva.git /app-cloned &&\
+    mv -f /app-cloned/* /app &&\
+    chown -R duva:duva /app
+
+# Install application requirements
+RUN pip install --no-cache-dir -U pip && pip install --no-cache-dir -r /app/requirements.pip
+
+EXPOSE 8000
+
+CMD ["/start.sh"]
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..a909a50
--- /dev/null
+++ b/README.md
@@ -0,0 +1,86 @@
+# Duva
+
+[![Build Status](https://travis-ci.com/onaio/duva.svg?branch=main)](https://travis-ci.com/github/onaio/duva)
+
+Duva is an API built using the [FastAPI](https://github.com/tiangolo/fastapi) framework that provides functionality to create & periodically update Tableau [Hyper](https://www.tableau.com/products/new-features/hyper) databases from CSV files. Currently, the application supports connecting to an [OnaData](https://github.com/onaio/onadata) server, from which it pulls XLSForm data and periodically exports it to a Tableau Hyper database.
+
+## Requirements
+
+- Python 3.6+
+- Redis
+
+## Installation
+
+### Via Docker
+
+The application comes with a `docker-compose.yml` file to facilitate easier installation of the project. _Note: The `docker-compose.yml` file is tailored for development environments_
+
+To start up the application via [Docker](https://www.docker.com/products/docker-desktop), run the `docker-compose up` command.
+
+### Alternative Installation
+
+1. Clone repository
+
+```sh
+$ git clone https://github.com/onaio/duva.git
+```
+
+2. Create & start [a virtual environment](https://virtualenv.pypa.io/en/latest/installation.html) to install dependencies
+
+```sh
+$ virtualenv duva
+$ source duva/bin/activate
+```
+
+3. Install base dependencies
+
+```sh
+$ pip install -r requirements.pip
+```
+
+4. (Optional: For developer environments) Install development dependencies.
+
+```sh
+$ pip install -r dev-requirements.pip
+```
+
+At this point the application can be started. _Note: Ensure the Redis server has been started._
+
+```
+$ ./scripts/start.sh
+```
+
+## Configuration
+
+The application can be configured either by manually editing the `app/settings.py` file or via environment variables, e.g. `export APP_NAME="Duva"`. More information on this is available [here](https://fastapi.tiangolo.com/advanced/settings).
+
+## API Documentation
+
+Documentation on the API endpoints provided by the application can be accessed by first running the application and accessing the `/docs` route.
+
+## Testing
+
+This project utilizes [tox](https://tox.readthedocs.io/en/latest/) for testing. In order to run the test suite within this project, run the following commands:
+
+```
+$ pip install tox
+$ tox
+```
+
+Alternatively, if you'd like to test the application with only the Python version currently installed on your computer, follow these steps:
+
+1. Install the developer dependencies
+
+```sh
+$ pip install -r dev-requirements.pip
+```
+
+2. Run the test suite using [pytest](https://docs.pytest.org/en/stable/)
+
+```sh
+$ ./scripts/run-tests.sh
+```
+OR
+```sh
+$ PYTHONPATH=. pytest -s app/tests
+```
diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md
new file mode 100644
index 0000000..3c455b2
--- /dev/null
+++ b/RELEASE_NOTES.md
@@ -0,0 +1,43 @@
+# Release Notes
+
+All release notes for this project will be documented in this file; this project follows [Semantic Versioning](https://semver.org/).
+
+## v0.0.1 - 2021-03-15
+
+This is the first release :confetti_ball:.
+
+Project Breakdown: Duva is a RESTful API that allows users to easily create & manage [Tableau Hyper](https://www.tableau.com/products/new-features/hyper) databases.
+
+### Key Features as of v0.0.1:
+
+- Supports automatic creation and updates of Hyper databases from an [OnaData](https://github.com/onaio/onadata) server; the application utilizes OnaData's export functionality to create and update the databases.
+- Supports creation of Hyper databases from a CSV file.
+
+### Sample Flows:
+
+#### One-off Hyper database creation from a CSV file:
+
+The application, as mentioned above, supports creation of a one-time Hyper database from a CSV file; these databases are not updated after creation.
+
+![one-off hyper database creation](./docs/flow-diagrams/one-off-hyper-database-flow.png)
+
+This flow is ideal for one-off Hyper databases or for servers where automatic creation & updates are not supported. *NB: As of v0.0.1 the application only supports OnaData servers.*
+
+#### Automatic creation and updates of Hyper databases for OnaData servers
+
+In order to use this flow with a desired server, the user first has to register a new `Server` object, which is used to authenticate the application and users, allowing the application to pull data on behalf of the user on a scheduled basis in order to update the managed Hyper database.
+
+Server registration flow (one-time flow for new servers):
+
+![server registration flow](./docs/flow-diagrams/server-registration-flow.png)
+
+After a new server is registered, users from the registered server are able to create
+managed Hyper database files.
+
+![managed hyper database flow](./docs/flow-diagrams/managed-hyper-database-flow.png)
+
+*During the creation of the managed Hyper database, users can specify a Tableau server where the Hyper database should be published to after every update. For more information on how to configure this, please view the API docs on a deployed instance of the application (/docs).*
+
+### Known Limitations of v0.0.1
+
+- The application currently uses session cookies to authenticate users; there are plans to phase out session cookies in favor of API tokens. For now, users may need to clear their cookies in order to log out.
\ No newline at end of file
diff --git a/alembic.ini b/alembic.ini
new file mode 100644
index 0000000..39e691a
--- /dev/null
+++ b/alembic.ini
@@ -0,0 +1,82 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+script_location = app/alembic
+
+# template used to generate migration files
+# file_template = %%(rev)s_%%(slug)s
+
+# timezone to use when rendering the date
+# within the migration file as well as the filename.
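The Configuration section of the README above, and the `from app.settings import settings` imports used throughout this change, depend on an `app/settings.py` module that is not part of this diff. Purely as a hedged sketch: the field names below are inferred from how `settings` is used elsewhere in the diff (`database_url`, `redis_url`, `media_path`, `secret_key`, and so on); the defaults and types are assumptions, not the actual module.

```python
# Hedged sketch of what app/settings.py might look like; not part of this diff.
# Field names are inferred from usages such as settings.database_url; defaults are assumed.
from pydantic import BaseSettings


class Settings(BaseSettings):
    # Each field can be overridden with an environment variable, e.g. APP_NAME.
    app_name: str = "Duva"
    app_description: str = "Tableau Hyper database API"
    app_version: str = "v0.0.1"
    app_host: str = "127.0.0.1"
    app_port: int = 8000
    debug: bool = False
    database_url: str = "sqlite:///./duva.db"
    db_connect_args: dict = {}
    redis_url: str = "redis://localhost:6379/1"
    media_path: str = "./media"
    secret_key: str = ""  # must be a valid Fernet key; see EncryptionMixin in app/models.py
    s3_bucket: str = ""
    s3_region: str = ""
    sentry_dsn: str = ""
    download_url_lifetime: int = 3600
    schedule_all_active: bool = True


settings = Settings()
```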
+# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; this defaults +# to app/alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path +# version_locations = %(here)s/bar %(here)s/bat app/alembic/versions + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +hooks=black +black.type=console_scripts +black.entrypoint=black +black.options=-l 79 + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/alembic/README b/app/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/app/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/app/alembic/env.py b/app/alembic/env.py new file mode 100644 index 0000000..d4c623a --- /dev/null +++ b/app/alembic/env.py @@ -0,0 +1,71 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context +from app.models import Base +from app.settings import settings + + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) + +config.set_main_option("sqlalchemy.url", settings.database_url) + +target_metadata = Base.metadata + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/app/alembic/script.py.mako b/app/alembic/script.py.mako new file mode 100644 index 0000000..2c01563 --- /dev/null +++ b/app/alembic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/app/alembic/versions/0b9ae1eb0b30_initial_migration.py b/app/alembic/versions/0b9ae1eb0b30_initial_migration.py new file mode 100644 index 0000000..3a41043 --- /dev/null +++ b/app/alembic/versions/0b9ae1eb0b30_initial_migration.py @@ -0,0 +1,91 @@ +"""Initial migration + +Revision ID: 0b9ae1eb0b30 +Revises: +Create Date: 2021-01-21 15:16:36.435591 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "0b9ae1eb0b30" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "server", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("url", sa.String(), nullable=True), + sa.Column("client_id", sa.String(), nullable=True), + sa.Column("client_secret", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("url"), + ) + op.create_table( + "user", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("username", sa.String(), nullable=True), + sa.Column("refresh_token", sa.String(), nullable=True), + sa.Column("server", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["server"], ["server.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("server", "username", name="_server_user_uc"), + ) + op.create_index(op.f("ix_user_id"), "user", ["id"], unique=False) + op.create_table( + "configuration", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("server_address", sa.String(), nullable=True), + sa.Column("site_name", sa.String(), nullable=True), + sa.Column("token_name", sa.String(), nullable=True), + sa.Column("token_value", sa.String(), nullable=True), + sa.Column("project_name", sa.String(), nullable=True), + sa.Column("user", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["user"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "server_address", + "token_name", + "user", + name="_server_token_name_uc", + ), + ) + op.create_index(op.f("ix_configuration_id"), "configuration", ["id"], unique=False) + op.create_table( + "hyper_file", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("filename", sa.String(), nullable=True), + sa.Column("user", sa.Integer(), nullable=True), + sa.Column("form_id", sa.Integer(), nullable=False), + sa.Column("last_updated", sa.DateTime(), 
nullable=True), + sa.Column("last_synced", sa.DateTime(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("configuration_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["configuration_id"], ["configuration.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint(["user"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("filename", name="hyper_file_filename_key"), + sa.UniqueConstraint("user", "form_id", name="_user_form_id_uc"), + ) + op.create_index(op.f("ix_hyper_file_id"), "hyper_file", ["id"], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_hyper_file_id"), table_name="hyper_file") + op.drop_table("hyper_file") + op.drop_index(op.f("ix_configuration_id"), table_name="configuration") + op.drop_table("configuration") + op.drop_index(op.f("ix_user_id"), table_name="user") + op.drop_table("user") + op.drop_table("server") + # ### end Alembic commands ### diff --git a/app/alembic/versions/60383c2a9b44_alter_hyper_file_filename_field_to_be_.py b/app/alembic/versions/60383c2a9b44_alter_hyper_file_filename_field_to_be_.py new file mode 100644 index 0000000..c1126bb --- /dev/null +++ b/app/alembic/versions/60383c2a9b44_alter_hyper_file_filename_field_to_be_.py @@ -0,0 +1,31 @@ +"""Alter Hyper file filename field to be non-unique + +Revision ID: 60383c2a9b44 +Revises: 0b9ae1eb0b30 +Create Date: 2021-02-12 16:44:45.035342 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "60383c2a9b44" +down_revision = "0b9ae1eb0b30" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + try: + op.drop_constraint("hyper_file_filename_key", "hyper_file") + except NotImplementedError: + with op.batch_alter_table("hyper_file", schema=None) as batch_op: + batch_op.drop_constraint("hyper_file_filename_key") + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint("hyper_file_filename_key", "hyper_file", ["filename"]) + # ### end Alembic commands ### diff --git a/app/alembic/versions/8a3e2f1927b8_add_file_status_field.py b/app/alembic/versions/8a3e2f1927b8_add_file_status_field.py new file mode 100644 index 0000000..ab34484 --- /dev/null +++ b/app/alembic/versions/8a3e2f1927b8_add_file_status_field.py @@ -0,0 +1,32 @@ +"""Add file_status field + +Revision ID: 8a3e2f1927b8 +Revises: 60383c2a9b44 +Create Date: 2021-02-23 11:30:23.732253 + +""" +from app.models import ChoiceType, schemas +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "8a3e2f1927b8" +down_revision = "60383c2a9b44" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "hyper_file", + sa.Column("file_status", ChoiceType(schemas.FileStatusEnum), nullable=True), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("hyper_file", "file_status") + # ### end Alembic commands ### diff --git a/app/alembic/versions/e28d24caaf56_add_meta_data_field.py b/app/alembic/versions/e28d24caaf56_add_meta_data_field.py new file mode 100644 index 0000000..819d293 --- /dev/null +++ b/app/alembic/versions/e28d24caaf56_add_meta_data_field.py @@ -0,0 +1,28 @@ +"""Add meta_data field + +Revision ID: e28d24caaf56 +Revises: 8a3e2f1927b8 +Create Date: 2021-04-19 16:42:54.162670 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "e28d24caaf56" +down_revision = "8a3e2f1927b8" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("hyper_file", sa.Column("meta_data", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("hyper_file", "meta_data") + # ### end Alembic commands ### diff --git a/app/common_tags.py b/app/common_tags.py new file mode 100644 index 0000000..37bb537 --- /dev/null +++ b/app/common_tags.py @@ -0,0 +1,11 @@ +# Common Tags +HYPER_PROCESS_CACHE_KEY = "HYPER_PROCESS" + +EVENT_STATUS_SUFFIX = "-event-status" + +ONADATA_TOKEN_ENDPOINT = "/o/token/" +ONADATA_FORMS_ENDPOINT = "/api/v1/forms" +ONADATA_USER_ENDPOINT = "/api/v1/user" + +SYNC_FAILURES_METADATA = "sync-failures" +JOB_ID_METADATA = "job-id" diff --git a/app/database.py b/app/database.py new file mode 100644 index 0000000..48c3761 --- /dev/null +++ b/app/database.py @@ -0,0 +1,11 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +from app.settings import settings + + +engine = create_engine(settings.database_url, connect_args=settings.db_connect_args) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() diff --git a/app/jobs/jobs.py b/app/jobs/jobs.py new file mode 100644 index 0000000..1ee6c8c --- /dev/null +++ b/app/jobs/jobs.py @@ -0,0 +1,26 @@ +from fastapi_cache import caches + +from fastapi_cache.backends.redis import CACHE_KEY, RedisCacheBackend +from tableauhyperapi import HyperProcess, Telemetry +from app.common_tags import HYPER_PROCESS_CACHE_KEY +from app.settings import settings + +from app.utils.onadata_utils import start_csv_import_to_hyper + + +def csv_import_job(instance_id): + # Connect to redis cache + rc = RedisCacheBackend(settings.redis_url) + caches.set(CACHE_KEY, rc) + + # Check if Hyper Process has started + # Note: Doing this in order to ensure only one + # Hyper process is started. 
+ if not caches.get(HYPER_PROCESS_CACHE_KEY): + caches.set( + HYPER_PROCESS_CACHE_KEY, + HyperProcess(telemetry=Telemetry.SEND_USAGE_DATA_TO_TABLEAU), + ) + process: HyperProcess = caches.get(HYPER_PROCESS_CACHE_KEY) + + start_csv_import_to_hyper(instance_id, process) diff --git a/app/jobs/scheduler.py b/app/jobs/scheduler.py new file mode 100644 index 0000000..9e30a59 --- /dev/null +++ b/app/jobs/scheduler.py @@ -0,0 +1,47 @@ +import os + +from redis import Redis +from rq import Queue +from rq.job import Job +from rq_scheduler import Scheduler +from typing import Callable + +QUEUE_NAME = os.environ.get("QUEUE_NAME", "default") +CRON_SCHEDULE = os.environ.get("CRON_SCHEDULE", "*/15 * * * *") +TASK_TIMEOUT = os.environ.get("TASK_TIMEOUT", "3600") +REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/1") +REDIS_CONN = Redis.from_url(REDIS_URL) +QUEUE = Queue(QUEUE_NAME, connection=REDIS_CONN) +SCHEDULER = Scheduler(queue=QUEUE, connection=REDIS_CONN) + + +def cancel_job(job_id, job_args: list = None, func_name: str = None): + SCHEDULER.cancel(job_id) + + if job_args and func_name: + for job in SCHEDULER.get_jobs(): + if job.func_name == func_name and job.args == job_args: + SCHEDULER.cancel(job) + + print(f"Job {job_id} cancelled ....") + + +def clear_scheduler_queue(): + for job in SCHEDULER.get_jobs(): + cancel_job(job) + + +def schedule_cron_job(job_func: Callable, args_list) -> Job: + job = SCHEDULER.cron( + CRON_SCHEDULE, # A cron string (e.g. "0 0 * * 0") + func=job_func, # Function to be queued + args=args_list, # Arguments passed into function when executed + kwargs={}, # Keyword arguments passed into function when executed + repeat=None, # Repeat this number of times (None means repeat forever) + queue_name=QUEUE_NAME, # In which queue the job should be put in + meta={}, # Arbitrary pickleable data on the job itself + use_local_timezone=False, # Interpret hours in the local timezone + timeout=int(TASK_TIMEOUT), # How long jobs can run for + ) + print(f"Job {job.id} scheduled ....") + return job diff --git a/app/jobs/settings.py b/app/jobs/settings.py new file mode 100644 index 0000000..459ee79 --- /dev/null +++ b/app/jobs/settings.py @@ -0,0 +1,15 @@ +""" +Settings file for RQ Workers +""" +import os +import sentry_sdk +from sentry_sdk.integrations.rq import RqIntegration + +from app.settings import settings + +# Init sentry +if settings.sentry_dsn: + sentry_sdk.init(settings.sentry_dsn, integrations=[RqIntegration()]) + +REDIS_URL = settings.redis_url +QUEUES = [os.environ.get("QUEUE_NAME", "default")] diff --git a/app/jobs/worker.py b/app/jobs/worker.py new file mode 100644 index 0000000..1bd0cc1 --- /dev/null +++ b/app/jobs/worker.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +import os + +import sentry_sdk +from app.settings import settings +from redis import Redis +from rq import Connection, Worker, Queue +from sentry_sdk.integrations.rq import RqIntegration + +# Preload libraries + +QUEUE_NAME = os.environ.get("QUEUE_NAME", "default") + + +redis_conn = Redis.from_url(settings.redis_url) + +# Provide queue names to listen to as arguments to this script, +# similar to rq worker +with Connection(): + if settings.sentry_dsn: + sentry_sdk.init(settings.sentry_dsn, integrations=[RqIntegration()]) + queue = Queue(QUEUE_NAME, connection=redis_conn) + + w = Worker(queue, connection=redis_conn) + w.work() diff --git a/app/libs/s3/client.py b/app/libs/s3/client.py new file mode 100644 index 0000000..43e35b1 --- /dev/null +++ b/app/libs/s3/client.py @@ -0,0 +1,64 @@ +import boto3 
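`app/routers/file.py` later in this diff schedules syncs through `schedule_hyper_file_cron_job` from `app/utils/onadata_utils.py`, a module that is not included here. Based on `scheduler.py` and `jobs.py` above, the wiring presumably resembles the hedged sketch below; `schedule_hyper_file_sync` and `unschedule_hyper_file_sync` are hypothetical names used only for illustration.

```python
# Hypothetical glue between the RQ scheduler and the CSV import job; the real
# logic lives in app/utils/onadata_utils.py, which is not part of this diff.
from app.jobs.jobs import csv_import_job
from app.jobs.scheduler import cancel_job, schedule_cron_job


def schedule_hyper_file_sync(hyper_file_id: int) -> str:
    """Schedule a recurring CSV-to-Hyper import for a HyperFile and return the job ID."""
    job = schedule_cron_job(csv_import_job, [hyper_file_id])
    return job.id


def unschedule_hyper_file_sync(job_id: str, hyper_file_id: int) -> None:
    # cancel_job also sweeps the scheduler for duplicate jobs with the same
    # function and arguments before reporting the cancellation.
    cancel_job(job_id, [hyper_file_id], "app.jobs.jobs.csv_import_job")
```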
+ +from botocore.exceptions import ClientError + +from app.settings import settings + + +class S3Client: + """ + This class encapsulates s3 client provided by boto3 + + """ + + def __init__(self): + self.s3 = boto3.resource("s3", region_name=settings.s3_region) + + def upload(self, path, file_name): + """ + uploads file in the given path to s3 with the given filename + """ + try: + self.s3.meta.client.upload_file(path, settings.s3_bucket, file_name) + except ClientError: + return False + return True + + def download(self, path, file_name): + """ + Downloads file_name in s3 to path + """ + try: + self.s3.meta.client.download_file(settings.s3_bucket, file_name, path) + except ClientError: + return False + return True + + def delete(self, file_path): + """ + Deletes file_path in S3 + """ + try: + resp = self.s3.meta.client.delete_object( + Bucket=settings.s3_bucket, Key=file_path + ) + except ClientError: + return False + return resp.get("DeleteMarker") + + def generate_presigned_download_url(self, file_path: str, expiration: int = 3600): + """ + Generates a presigned Download URL + + file_path :string: Path to the file in the S3 Bucket + expirationg :integer: The duration in seconds that the URL should be valid for + """ + try: + response = self.s3.meta.client.generate_presigned_url( + "get_object", + Params={"Bucket": settings.s3_bucket, "Key": file_path}, + ExpiresIn=expiration, + ) + except ClientError: + return None + return response diff --git a/app/libs/tableau/client.py b/app/libs/tableau/client.py new file mode 100644 index 0000000..aef5168 --- /dev/null +++ b/app/libs/tableau/client.py @@ -0,0 +1,50 @@ +import tableauserverclient as TSC + +from pathlib import Path + +from app.models import Configuration + + +class TableauClient: + def __init__(self, configuration: Configuration): + self.project_name = configuration.project_name + self.token_name = configuration.token_name + self.token_value = Configuration.decrypt_value(configuration.token_value) + self.site_name = configuration.site_name + self.server_address = configuration.server_address + + def publish_hyper(self, hyper_name): + """ + Signs in and publishes an extract directly to Tableau Online/Server + """ + + # Sign in to server + tableau_auth = TSC.PersonalAccessTokenAuth( + token_name=self.token_name, + personal_access_token=self.token_value, + site_id=self.site_name, + ) + server = TSC.Server(self.server_address, use_server_version=True) + + print(f"Signing into {self.site_name} at {self.server_address}") + with server.auth.sign_in(tableau_auth): + # Define publish mode - Overwrite, Append, or CreateNew + publish_mode = TSC.Server.PublishMode.Overwrite + + # Get project_id from project_name + # all_projects, _ = server.projects.get() + for project in TSC.Pager(server.projects): + if project.name == self.project_name: + project_id = project.id + + # Create the datasource object with the project_id + datasource = TSC.DatasourceItem(project_id) + + print(f"Publishing {hyper_name} to {self.project_name}...") + + path_to_database = Path(hyper_name) + # Publish datasource + datasource = server.datasources.publish( + datasource, path_to_database, publish_mode + ) + print("Datasource published. 
Datasource ID: {0}".format(datasource.id)) diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..0d4ea01 --- /dev/null +++ b/app/main.py @@ -0,0 +1,111 @@ +import os +import uvicorn +import sentry_sdk +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.templating import Jinja2Templates +from fastapi_cache import caches, close_caches +from fastapi_cache.backends.redis import CACHE_KEY, RedisCacheBackend +from tableauhyperapi import HyperProcess, Telemetry +from sentry_sdk.integrations.asgi import SentryAsgiMiddleware +from starlette.middleware.sessions import SessionMiddleware + +from app.common_tags import HYPER_PROCESS_CACHE_KEY +from app.database import engine +from app.models import Base +from app.settings import settings +from app.utils.onadata_utils import schedule_all_active_forms +from app.routers.file import router as file_router +from app.routers.oauth import router as oauth_router +from app.routers.server import router as server_router +from app.routers.configuration import router as configurations_router +from app.jobs.scheduler import clear_scheduler_queue + +Base.metadata.create_all(bind=engine) + +app = FastAPI( + title=settings.app_name, + description=settings.app_description, + version=settings.app_version, +) + +templates = Jinja2Templates(directory="app/templates") + +# Include middlewares +app.add_middleware( + SessionMiddleware, + secret_key=settings.secret_key, + https_only=settings.enable_secure_sessions, + same_site=settings.session_same_site, +) +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_allowed_origins, + allow_credentials=settings.cors_allow_credentials, + allow_methods=settings.cors_allowed_methods, + allow_headers=settings.cors_allowed_headers, + max_age=settings.cors_max_age, +) +if settings.sentry_dsn: + sentry_sdk.init(dsn=settings.sentry_dsn, release=settings.app_version) + app.add_middleware(SentryAsgiMiddleware) + +# Include routes +app.include_router(server_router, tags=["Server Configuration"]) +app.include_router(oauth_router, tags=["OAuth2"]) +app.include_router(file_router, tags=["Hyper File"]) +app.include_router(configurations_router, tags=["Tableau Server Configuration"]) + + +@app.get("/", tags=["Application"]) +def home(request: Request): + return { + "app_name": settings.app_name, + "app_description": settings.app_description, + "app_version": settings.app_version, + "docs_url": str(request.base_url.replace(path=app.docs_url)), + } + + +@app.on_event("startup") +async def on_startup() -> None: + # Ensure media file path exists + if not os.path.isdir(settings.media_path): + os.mkdir(settings.media_path) + + # Connect to redis cache + rc = RedisCacheBackend(settings.redis_url) + caches.set(CACHE_KEY, rc) + + # Check if Hyper Process has started + # Note: Doing this in order to ensure only one + # Hyper process is started. 
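For context on how `TableauClient` from `app/libs/tableau/client.py` above would be driven, here is a minimal, hedged usage sketch; the configuration ID and the local `.hyper` path are placeholders, and a `Configuration` row is assumed to already exist.

```python
# Hedged usage sketch for TableauClient; the configuration ID and the .hyper
# path below are placeholders, and a Configuration row is assumed to exist.
from app.database import SessionLocal
from app.libs.tableau.client import TableauClient
from app.models import Configuration

db = SessionLocal()
config = Configuration.get(db, 1)  # hypothetical configuration ID
if config:
    client = TableauClient(configuration=config)
    client.publish_hyper("media/1_example.hyper")  # hypothetical local file
db.close()
```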
+ if not caches.get(HYPER_PROCESS_CACHE_KEY): + caches.set( + HYPER_PROCESS_CACHE_KEY, + HyperProcess(telemetry=Telemetry.SEND_USAGE_DATA_TO_TABLEAU), + ) + + if settings.schedule_all_active: + clear_scheduler_queue() + schedule_all_active_forms(close_db=True) + + +@app.on_event("shutdown") +async def on_shutdown() -> None: + await close_caches() + + # Check if hyper process is running and shut it down + process: HyperProcess = caches.get(HYPER_PROCESS_CACHE_KEY) + if process: + print("Shutting down hyper process") + process.close() + + +if __name__ == "__main__": + uvicorn.run( + "app.main:app", + host=settings.app_host, + port=settings.app_port, + reload=settings.debug, + ) diff --git a/app/models.py b/app/models.py new file mode 100644 index 0000000..95b5099 --- /dev/null +++ b/app/models.py @@ -0,0 +1,231 @@ +import sqlalchemy.types as types +from typing import Optional +from cryptography.fernet import Fernet +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Integer, + String, + UniqueConstraint, + JSON, +) +from sqlalchemy.orm import Session, relationship +from sqlalchemy.sql.schema import ForeignKey + +from app import schemas +from app.common_tags import SYNC_FAILURES_METADATA, JOB_ID_METADATA +from app.database import Base +from app.settings import settings +from app.libs.s3.client import S3Client + + +class ChoiceType(types.TypeDecorator): + """ + ChoiceField Implementation for SQL Alchemy + + Credits: https://stackoverflow.com/a/6264027 + """ + + impl = types.String + + def __init__(self, enum, **kwargs): + self.choices = enum + super(ChoiceType, self).__init__(**kwargs) + + def process_bind_param(self, value, dialect): + for member in dir(self.choices): + if getattr(self.choices, member) == value: + return member + + def process_result_value(self, value, dialect): + return getattr(self.choices, value).value + + +class ModelMixin(object): + @classmethod + def get(cls, db: Session, object_id: int): + return db.query(cls).filter(cls.id == object_id).first() + + @classmethod + def get_all(cls, db: Session, skip: int = 0, limit: int = 100): + return db.query(cls).offset(skip).limit(limit).all() + + @classmethod + def delete(cls, db: Session, object_id: int): + return ( + db.query(cls) + .filter(cls.id == object_id) + .delete(synchronize_session="fetch") + ) + + +class EncryptionMixin(object): + @classmethod + def _get_encryption_key(cls): + return Fernet(settings.secret_key) + + @classmethod + def encrypt_value(cls, raw_value): + key = cls._get_encryption_key() + return key.encrypt(raw_value.encode("utf-8")).decode("utf-8") + + @classmethod + def decrypt_value(cls, encrypted_value): + key = cls._get_encryption_key() + return key.decrypt(encrypted_value.encode("utf-8")).decode("utf-8") + + +class HyperFile(ModelMixin, Base): + __tablename__ = "hyper_file" + __table_args__ = (UniqueConstraint("user", "form_id", name="_user_form_id_uc"),) + + id = Column(Integer, primary_key=True, index=True) + filename = Column(String, unique=False) + user = Column(Integer, ForeignKey("user.id", ondelete="CASCADE")) + form_id = Column(Integer, nullable=False) + last_updated = Column(DateTime) + last_synced = Column(DateTime) + is_active = Column(Boolean, default=True) + file_status = Column( + ChoiceType(schemas.FileStatusEnum), + default=schemas.FileStatusEnum.file_unavailable, + ) + configuration_id = Column( + Integer, ForeignKey("configuration.id", ondelete="SET NULL") + ) + meta_data = Column(JSON, default={SYNC_FAILURES_METADATA: 0, JOB_ID_METADATA: ""}) + configuration = 
relationship("Configuration") + + def get_file_path(self, db: Session): + user = User.get(db, self.user) + s3_path = f"{user.server}/{user.username}/{self.form_id}_{self.filename}" + return s3_path + + def retrieve_latest_file(self, db: Session): + local_path = f"{settings.media_path}/{self.form_id}_{self.filename}" + s3_path = self.get_file_path(db) + client = S3Client() + client.download(local_path, s3_path) + return local_path + + @classmethod + def get_using_file_create(cls, db: Session, file_create: schemas.FileCreate): + return ( + db.query(cls) + .filter(cls.user == file_create.user, cls.form_id == file_create.form_id) + .first() + ) + + @classmethod + def get_active_files(cls, db: Session): + return db.query(cls).filter(cls.is_active == True).all() # noqa + + @classmethod + def create(cls, db: Session, hyperfile: schemas.FileCreate): + instance = cls(**hyperfile.dict()) + db.add(instance) + db.commit() + db.refresh(instance) + return instance + + @classmethod + def filter(cls, user: schemas.User, form_id: int, db: Session): + return db.query(cls).filter(cls.user == user.id, cls.form_id == form_id).all() + + +class Server(ModelMixin, EncryptionMixin, Base): + __tablename__ = "server" + + id = Column(Integer, primary_key=True) + url = Column(String, unique=True) + client_id = Column(String) + client_secret = Column(String) + + @classmethod + def get_using_url(cls, db: Session, url: str) -> Optional[schemas.Server]: + return db.query(cls).filter(cls.url == url).first() + + @classmethod + def create(cls, db: Session, server: schemas.ServerCreate) -> schemas.Server: + encrypted_secret = cls.encrypt_value(server.client_secret) + server = cls( + url=server.url, + client_id=server.client_id, + client_secret=encrypted_secret, + ) + db.add(server) + db.commit() + db.refresh(server) + return server + + +class User(ModelMixin, EncryptionMixin, Base): + __tablename__ = "user" + __table_args__ = (UniqueConstraint("server", "username", name="_server_user_uc"),) + + id = Column(Integer, primary_key=True, index=True) + username = Column(String) + refresh_token = Column(String) + server = Column(Integer, ForeignKey("server.id", ondelete="CASCADE")) + files = relationship("HyperFile") + + @classmethod + def get_using_username(cls, db: Session, username: str): + return db.query(cls).filter(cls.username == username).first() + + @classmethod + def get_using_server_and_username(cls, db: Session, username: str, server_id: int): + return ( + db.query(cls) + .filter(cls.username == username, cls.server == server_id) + .first() + ) + + @classmethod + def create(cls, db: Session, user: schemas.User): + encrypted_token = cls.encrypt_value(user.refresh_token) + user = cls( + username=user.username, refresh_token=encrypted_token, server=user.server + ) + db.add(user) + db.commit() + db.refresh(user) + return user + + +class Configuration(ModelMixin, EncryptionMixin, Base): + """ + Tableau server authentication configurations; Used to publish + Hyper files. 
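`Server`, `User`, and `Configuration` all persist their secrets through `EncryptionMixin`, which wraps `cryptography.fernet.Fernet` keyed by `settings.secret_key`. A short, hedged round-trip sketch follows; it assumes `secret_key` holds a value produced by `Fernet.generate_key()`, since anything else would make `Fernet()` raise.

```python
# Hedged round-trip for EncryptionMixin; assumes settings.secret_key is a valid
# Fernet key, e.g. the output of Fernet.generate_key().decode().
from app.models import Server

ciphertext = Server.encrypt_value("my-oauth-client-secret")  # what gets stored in the DB
plaintext = Server.decrypt_value(ciphertext)                 # recovered at request time
assert plaintext == "my-oauth-client-secret"
```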
+ """ + + __tablename__ = "configuration" + __table_args__ = ( + UniqueConstraint( + "server_address", "token_name", "user", name="_server_token_name_uc" + ), + ) + + id = Column(Integer, primary_key=True, index=True) + server_address = Column(String) + site_name = Column(String) + token_name = Column(String) + token_value = Column(String) + project_name = Column(String, default="default") + user = Column(Integer, ForeignKey("user.id", ondelete="CASCADE")) + + @classmethod + def filter_using_user_id(cls, db: Session, user_id: int): + return db.query(cls).filter(cls.user == user_id) + + @classmethod + def create(cls, db: Session, config: schemas.ConfigurationCreate): + encrypted_token = cls.encrypt_value(config.token_value) + data = config.dict() + data.update({"token_value": encrypted_token}) + configuration = cls(**data) + db.add(configuration) + db.commit() + db.refresh(configuration) + return configuration diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/routers/configuration.py b/app/routers/configuration.py new file mode 100644 index 0000000..1f4ba17 --- /dev/null +++ b/app/routers/configuration.py @@ -0,0 +1,142 @@ +# Routes for the Tableau Configuration (/configurations) endpoint +from typing import List + +from fastapi import Depends +from fastapi.exceptions import HTTPException +from fastapi.requests import Request +from fastapi.routing import APIRouter +from sqlalchemy.orm import Session +from psycopg2.errors import UniqueViolation +from sqlalchemy.exc import IntegrityError + +from app import schemas +from app.models import Configuration, User +from app.utils.auth_utils import IsAuthenticatedUser +from app.utils.utils import get_db + + +router = APIRouter() + + +@router.get( + "/api/v1/configurations", + status_code=200, + response_model=List[schemas.ConfigurationListResponse], +) +def list_configurations( + request: Request, + user: User = Depends(IsAuthenticatedUser()), + db: Session = Depends(get_db), +): + """ + Lists out all the Tableau Configurations currently accessible for to the logged in user + """ + resp = [] + configurations = Configuration.filter_using_user_id(db, user.id) + + for config in configurations: + config = schemas.ConfigurationListResponse.from_orm(config) + config.url = f"{request.base_url.scheme}://{request.base_url.netloc}" + config.url += router.url_path_for("get_configuration", config_id=config.id) + resp.append(config) + return resp + + +@router.get( + "/api/v1/configurations/{config_id}", + status_code=200, + response_model=schemas.ConfigurationResponse, +) +def get_configuration( + config_id: int, + user: User = Depends(IsAuthenticatedUser()), + db: Session = Depends(get_db), +): + """ + Retrieve a specific configuration + """ + config = Configuration.get(db, config_id) + + if config and config.user == user.id: + return config + else: + raise HTTPException(status_code=404, detail="Tableau configuration not found.") + + +@router.post( + "/api/v1/configurations", + status_code=201, + response_model=schemas.ConfigurationResponse, +) +def create_configuration( + config_data: schemas.ConfigurationCreateRequest, + user: User = Depends(IsAuthenticatedUser()), + db: Session = Depends(get_db), +): + """ + Create a new Tableau Server Configuration that can be attached + to a hyper file to define where the hyper file should be pushed to. 
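A hedged example of calling this endpoint with `httpx` (already used elsewhere in the project); the host, bearer token, and Tableau values are placeholders, and the bearer token is assumed to come from the OAuth callback in `app/routers/oauth.py`.

```python
# Hedged request against POST /api/v1/configurations; every value is a placeholder.
import httpx

resp = httpx.post(
    "http://localhost:8000/api/v1/configurations",
    headers={"Authorization": "Bearer <session-token>"},
    json={
        "server_address": "https://my-tableau-server.example.com",
        "site_name": "my-site",
        "token_name": "duva-token",
        "token_value": "<tableau-personal-access-token>",
        "project_name": "default",
    },
)
print(resp.status_code, resp.json())
```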
+ """ + config_data = schemas.ConfigurationCreate( + user=user.id, + server_address=config_data.server_address, + site_name=config_data.site_name, + token_name=config_data.token_name, + token_value=config_data.token_value, + project_name=config_data.project_name, + ) + try: + config = Configuration.create(db, config_data) + return config + except (UniqueViolation, IntegrityError): + raise HTTPException(status_code=400, detail="Configuration already exists") + + +@router.patch( + "/api/v1/configurations/{config_id}", + status_code=200, + response_model=schemas.ConfigurationResponse, +) +def patch_configuration( + config_id: int, + config_data: schemas.ConfigurationPatchRequest, + user: User = Depends(IsAuthenticatedUser()), + db: Session = Depends(get_db), +): + """ + Partially update a Configuration + """ + config = Configuration.get(db, config_id) + + if config and config.user == user.id: + try: + for key, value in config_data.dict().items(): + if value: + if key == "token_value": + value = Configuration.encrypt_value(value) + setattr(config, key, value) + db.commit() + db.refresh(config) + return config + except (UniqueViolation, IntegrityError): + raise HTTPException(status_code=400, detail="Configuration already exists") + else: + raise HTTPException(404, detail="Tableau Configuration not found.") + + +@router.delete("/api/v1/configurations/{config_id}", status_code=204) +def delete_configuration( + config_id: int, + user: User = Depends(IsAuthenticatedUser()), + db: Session = Depends(get_db), +): + """ + Permanently delete a configuration + """ + config = Configuration.get(db, config_id) + + if config and config.user == user.id: + Configuration.delete(db, config.id) + db.commit() + else: + raise HTTPException(status_code=400) diff --git a/app/routers/file.py b/app/routers/file.py new file mode 100644 index 0000000..ad5ed4e --- /dev/null +++ b/app/routers/file.py @@ -0,0 +1,309 @@ +# Routes for the Hyperfile (/files) endpoint +import os +import shutil +from pathlib import Path +from tempfile import NamedTemporaryFile +from datetime import datetime, timedelta +from typing import List, Optional, Union + +from fastapi import BackgroundTasks, Depends, HTTPException, UploadFile, File, Request +from fastapi.routing import APIRouter +from fastapi.responses import FileResponse, JSONResponse +from fastapi_cache import caches +from redis.client import Redis +from sqlalchemy.orm import Session +from tableauhyperapi.hyperprocess import HyperProcess + +from app import schemas +from app.common_tags import HYPER_PROCESS_CACHE_KEY +from app.libs.s3.client import S3Client +from app.models import HyperFile, Configuration, User +from app.settings import settings +from app.utils.auth_utils import IsAuthenticatedUser +from app.utils.utils import get_db, get_redis_client +from app.utils.hyper_utils import handle_csv_import +from app.utils.onadata_utils import ( + ConnectionRequestError, + DoesNotExist, + UnsupportedForm, + create_or_get_hyperfile, + start_csv_import_to_hyper, + schedule_hyper_file_cron_job, + start_csv_import_to_hyper_job, +) + + +router = APIRouter() + + +def _create_hyper_file_response( + hyper_file: HyperFile, db: Session, request: Request +) -> schemas.FileResponseBody: + from app.main import app + + s3_client = S3Client() + file_path = hyper_file.get_file_path(db) + download_url = s3_client.generate_presigned_download_url( + file_path, expiration=settings.download_url_lifetime + ) + data = schemas.File.from_orm(hyper_file).dict() + if download_url: + expiry_date = datetime.utcnow() + 
timedelta( + seconds=settings.download_url_lifetime + ) + data.update( + { + "download_url": download_url, + "download_url_valid_till": expiry_date.isoformat(), + } + ) + if hyper_file.configuration_id: + config_url = f"{request.base_url.scheme}://{request.base_url.netloc}" + config_url += app.url_path_for( + "get_configuration", config_id=hyper_file.configuration_id + ) + data.update({"configuration_url": config_url}) + response = schemas.FileResponseBody(**data) + return response + + +@router.post("/api/v1/files", status_code=201, response_model=schemas.FileResponseBody) +def create_hyper_file( + request: Request, + file_request: schemas.FileRequestBody, + background_tasks: BackgroundTasks, + user: User = Depends(IsAuthenticatedUser()), + db: Session = Depends(get_db), + redis_client: Redis = Depends(get_redis_client), +): + """ + Creates a Hyper file object. + + JSON Data Parameters: + - `form_id`: An integer representing the ID of the form whose data should be exported + into a Hyperfile & tracked. + - `sync_immediately`: An optional boolean field that determines whether a forms data should + be synced immediately after creation of a Hyper file object. _Note: Hyper files are updated + periodically on a schedule by default i.e 15 minutes after creation of object or every 24 hours_ + - `configuration_id`: An integer representing the ID of a Configuration(_See docs on /api/v1/configurations route_). + Determines where the hyper file is pushed to after it has been updated with the latest form data. + """ + process: HyperProcess = caches.get(HYPER_PROCESS_CACHE_KEY) + configuration = None + try: + file_data = schemas.FileCreate(form_id=file_request.form_id, user=user.id) + if file_request.configuration_id: + configuration = Configuration.get(db, file_request.configuration_id) + if not configuration or not configuration.user == user.id: + raise HTTPException( + status_code=400, + detail=f"Tableau configuration with ID {file_request.configuration_id} not found", + ) + file_instance, created = create_or_get_hyperfile(db, file_data, process) + except (DoesNotExist, UnsupportedForm) as e: + raise HTTPException(status_code=400, detail=str(e)) + except ConnectionRequestError as e: + raise HTTPException(status_code=502, detail=str(e)) + else: + if not created: + raise HTTPException(status_code=400, detail="File already exists.") + + if configuration: + file_instance.configuration_id = configuration.id + + if file_request.sync_immediately: + background_tasks.add_task( + start_csv_import_to_hyper, file_instance.id, process + ) + background_tasks.add_task( + schedule_hyper_file_cron_job, + start_csv_import_to_hyper_job, + file_instance.id, + ) + file_instance.file_status = schemas.FileStatusEnum.queued.value + db.commit() + return _create_hyper_file_response(file_instance, db, request) + + +@router.get("/api/v1/files", response_model=List[schemas.FileListItem]) +def list_hyper_files( + request: Request, + user: User = Depends(IsAuthenticatedUser()), + form_id: Optional[str] = None, + db: Session = Depends(get_db), +): + """ + This endpoint lists out all the hyper files currently owned by the + logged in user. + + Query Parameters: + - `form_id`: An integer representing an ID of a form on the users authenticated + server. 
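Putting the two file endpoints above together, here is a hedged request sketch for registering a form and then listing the resulting Hyper files; the host, bearer token, and IDs are placeholders, and the request body mirrors the parameters documented on the create endpoint.

```python
# Hedged sketch: register form 42 for syncing, then list the user's hyper files.
import httpx

headers = {"Authorization": "Bearer <session-token>"}  # placeholder token
created = httpx.post(
    "http://localhost:8000/api/v1/files",
    headers=headers,
    json={"form_id": 42, "sync_immediately": True, "configuration_id": 1},
)
print(created.status_code, created.json())

listing = httpx.get("http://localhost:8000/api/v1/files", headers=headers)
print(listing.json())
```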
+ """ + response = [] + hyperfiles = [] + if form_id: + hyperfiles = HyperFile.filter(user, form_id, db) + else: + hyperfiles = user.files + + for hyperfile in hyperfiles: + url = request.base_url.scheme + "://" + request.base_url.netloc + url += router.url_path_for("get_hyper_file", file_id=hyperfile.id) + response.append( + schemas.FileListItem( + url=url, + id=hyperfile.id, + form_id=hyperfile.form_id, + filename=hyperfile.filename, + file_status=hyperfile.file_status, + ) + ) + return response + + +@router.get("/api/v1/files/{file_id}", response_model=schemas.FileResponseBody) +def get_hyper_file( + file_id: Union[str, int], + request: Request, + user: User = Depends(IsAuthenticatedUser()), + db: Session = Depends(get_db), +): + """ + Retrieves a specific hyper file. _This endpoint supports both `.json` and `.hyper` response_ + + The `.json` response provides the JSON representation of the hyper file object. While the `.hyper` + response provides a FileResponse that contains the latest hyper file download. + """ + response_type = None + file_parts = file_id.split(".") + if len(file_parts) == 2: + file_id, response_type = file_parts + + try: + file_id = int(file_id) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid file ID") + + hyperfile = HyperFile.get(db, file_id) + + if hyperfile and user.id == hyperfile.user: + if not response_type or response_type == "json": + return _create_hyper_file_response(hyperfile, db, request) + elif response_type == "hyper": + file_path = hyperfile.retrieve_latest_file(db) + if os.path.exists(file_path): + return FileResponse(file_path, filename=hyperfile.filename) + else: + raise HTTPException( + status_code=404, detail="File currently not available" + ) + else: + raise HTTPException(status_code=400, detail="Unsupported content type") + else: + raise HTTPException(status_code=404, detail="File not found") + + +@router.patch( + "/api/v1/files/{file_id}", status_code=200, response_model=schemas.FileResponseBody +) +def patch_hyper_file( + file_id: int, + request: Request, + data: schemas.FilePatchRequestBody, + user: User = Depends(IsAuthenticatedUser()), + db: Session = Depends(get_db), +): + """ + Partially updates a specific hyper file object + """ + hyper_file = HyperFile.get(db, file_id) + + if not hyper_file or hyper_file.user != user.id: + raise HTTPException(status_code=404, detail="File not found") + + configuration = Configuration.get(db, data.configuration_id) + if not configuration or not configuration.user == user.id: + raise HTTPException( + status_code=400, + detail=f"Tableau configuration with ID {data.configuration_id} not found", + ) + hyper_file.configuration_id = configuration.id + db.commit() + db.refresh(hyper_file) + return _create_hyper_file_response(hyper_file, db, request) + + +@router.post("/api/v1/files/csv_import", status_code=200, response_class=FileResponse) +def import_data(id_string: str, csv_file: UploadFile = File(...)): + """ + Experimental Endpoint: Creates and imports `csv_file` data into a hyper file. 
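A hedged sketch of exercising the experimental CSV import endpoint above; the CSV path and `id_string` are placeholders. The endpoint takes no authentication and streams the generated `.hyper` file back in the response.

```python
# Hedged sketch for POST /api/v1/files/csv_import; the CSV path and id_string
# are placeholders. The response body is the generated .hyper file.
import httpx

with open("survey_export.csv", "rb") as csv_file:
    resp = httpx.post(
        "http://localhost:8000/api/v1/files/csv_import",
        params={"id_string": "survey_export"},
        files={"csv_file": ("survey_export.csv", csv_file, "text/csv")},
    )

with open("survey_export.hyper", "wb") as out:
    out.write(resp.content)
```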
+ """ + process: HyperProcess = caches.get(HYPER_PROCESS_CACHE_KEY) + suffix = Path(csv_file.filename).suffix + csv_file.file.seek(0) + file_path = f"{settings.media_path}/{id_string}.hyper" + with NamedTemporaryFile(delete=False, suffix=suffix) as tmp_upload: + shutil.copyfileobj(csv_file.file, tmp_upload) + tmp_upload.flush() + handle_csv_import( + file_path=file_path, csv_path=Path(tmp_upload.name), process=process + ) + return FileResponse(file_path, filename=f"{id_string}.hyper") + + +@router.delete("/api/v1/files/{file_id}", status_code=204) +def delete_hyper_file( + file_id: int, + user: User = Depends(IsAuthenticatedUser()), + db: Session = Depends(get_db), +): + """ + Permanently delete a Hyper File Object + """ + hyper_file = HyperFile.get(db, file_id) + + if hyper_file and hyper_file.user == user.id: + # Delete file from S3 + s3_client = S3Client() + if s3_client.delete(hyper_file.get_file_path(db)): + # Delete Hyper File object from database + HyperFile.delete(db, file_id) + db.commit() + else: + raise HTTPException(status_code=400) + + +@router.post("/api/v1/files/{file_id}/sync") +def trigger_hyper_file_sync( + request: Request, + file_id: int, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), + user: User = Depends(IsAuthenticatedUser()), + redis_client: Redis = Depends(get_redis_client), +): + """ + Trigger Hyper file sync; Starts a process that updates the + hyper files data. + """ + hyper_file = HyperFile.get(db, file_id) + + if not hyper_file: + raise HTTPException(404, "File not found.") + if hyper_file.user == user.id: + status_code = 200 + if hyper_file.file_status not in [ + schemas.FileStatusEnum.queued, + schemas.FileStatusEnum.syncing, + ]: + process: HyperProcess = caches.get(HYPER_PROCESS_CACHE_KEY) + background_tasks.add_task(start_csv_import_to_hyper, hyper_file.id, process) + else: + status_code = 202 + + return JSONResponse( + {"message": "File syncing is currently on-going"}, status_code=status_code + ) + else: + raise HTTPException(401) diff --git a/app/routers/oauth.py b/app/routers/oauth.py new file mode 100644 index 0000000..1db3ffa --- /dev/null +++ b/app/routers/oauth.py @@ -0,0 +1,149 @@ +# Routes for the OAuth (/oauth) endpoint +import json +import uuid +from typing import Optional + +import httpx +import redis +from fastapi import Depends, HTTPException, Request +from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.routing import APIRouter + +from app import schemas +from app.common_tags import ONADATA_TOKEN_ENDPOINT, ONADATA_USER_ENDPOINT +from app.models import User, Server +from app.utils.utils import get_db, get_redis_client +from app.utils.auth_utils import create_session, IsAuthenticatedUser + +router = APIRouter() + + +@router.get("/api/v1/oauth/login", status_code=302) +def start_login_flow( + server_url: str, + redirect_url: Optional[str] = None, + user=Depends(IsAuthenticatedUser(raise_errors=False)), + db=Depends(get_db), + redis: redis.Redis = Depends(get_redis_client), +): + """ + Starts OAuth2 Code Flow; The flow authenticates a user against one of the configured + servers. _For more info on server configurations check the `/api/v1/server` docs_ + + This endpoint redirects the client to the `server_url` for authentication if the server + has a server configuration in the system. 
Once the user is authorized on the server + the user should be redirected back to `/api/v1/oauth/callback` which will handle + creation of a user session that will allow the user to access the applications Hyper File + resources. + """ + if not user: + server: Optional[schemas.Server] = Server.get_using_url(db, server_url) + if not server: + raise HTTPException(status_code=400, detail="Server not configured") + auth_state = {"server_id": server.id} + if redirect_url: + auth_state["redirect_url"] = redirect_url + + state_key = str(uuid.uuid4()) + redis.setex(state_key, 600, json.dumps(auth_state)) + url = f"{server.url}/o/authorize?client_id={server.client_id}&response_type=code&state={state_key}" + return RedirectResponse( + url=url, + status_code=302, + headers={ + "Cache-Control": "no-cache, no-store, revalidate", + }, + ) + else: + return RedirectResponse(url=redirect_url or "/", status_code=302) + + +@router.get( + "/api/v1/oauth/callback", + status_code=302, + responses={200: {"model": schemas.UserBearerTokenResponse}}, +) +def handle_oauth_callback( + code: str, + state: str, + request: Request, + db=Depends(get_db), + user=Depends(IsAuthenticatedUser(raise_errors=False)), + redis: redis.Redis = Depends(get_redis_client), +): + """ + Handles OAuth2 Code flow callback. This url should be registered + as the "redirect_uri" for your Server OAuth Application(Onadata). + + This endpoint creates a user session for the authorized user and authenticates + the user granting them access to the Hyper File API. + + User sessions last for 2 weeks. After the 2 weeks pass the user needs to re-authorize + with the application to gain access to the Hyper file API + """ + if user: + return RedirectResponse(url="/", status_code=302) + + auth_state = redis.get(state) + if not auth_state: + raise HTTPException( + status_code=401, detail="Authorization state can not be confirmed." 
+ ) + + auth_state = json.loads(auth_state) + redis.delete(state) + server: Optional[schemas.Server] = Server.get( + db, object_id=auth_state.get("server_id") + ) + redirect_url = auth_state.get("redirect_url") + data = { + "grant_type": "authorization_code", + "code": code, + "client_id": server.client_id, + } + url = f"{server.url}{ONADATA_TOKEN_ENDPOINT}" + resp = httpx.post( + url, + data=data, + auth=( + server.client_id, + Server.decrypt_value(server.client_secret), + ), + ) + + if resp.status_code == 200: + resp = resp.json() + access_token = resp.get("access_token") + refresh_token = resp.get("refresh_token") + + user_url = f"{server.url}{ONADATA_USER_ENDPOINT}.json" + headers = {"Authorization": f"Bearer {access_token}"} + resp = httpx.get(user_url, headers=headers) + if resp.status_code == 200: + resp = resp.json() + username = resp.get("username") + user = User.get_using_server_and_username(db, username, server.id) + if not user: + user_data = schemas.User( + username=username, refresh_token=refresh_token, server=server.id + ) + user = User.create(db, user_data) + else: + user.refresh_token = User.encrypt_value(refresh_token) + db.commit() + + request, session_data = create_session(user, redis, request) + if redirect_url: + return RedirectResponse( + redirect_url, + status_code=302, + headers={ + "Cache-Control": "no-cache, no-store, revalidate", + }, + ) + return JSONResponse( + schemas.UserBearerTokenResponse( + bearer_token=session_data.decode("utf-8") + ).dict() + ) + raise HTTPException(status_code=401, detail="Authentication failed.") diff --git a/app/routers/server.py b/app/routers/server.py new file mode 100644 index 0000000..04e9781 --- /dev/null +++ b/app/routers/server.py @@ -0,0 +1,63 @@ +# Routes for the Server (/server) endpoint +from typing import List + +from fastapi import Depends, HTTPException +from fastapi.routing import APIRouter +from starlette.datastructures import URL + +from app import schemas +from app.models import Server +from app.utils.utils import get_db + + +router = APIRouter() + + +@router.post("/api/v1/servers", response_model=schemas.ServerResponse, status_code=201) +def create_server_object(server: schemas.ServerCreate, db=Depends(get_db)): + """ + Create new Server configuration objects. + + Server configuration objects are used to authorize the + Duva Application against an OnaData server; Users authorize + against a server configuration. + + After creation of a server object, users & 3rd party applications + can utilize the OAuth login route with the server url as the `server_url` query param to authorize users and enable the application to pull & sync forms that the user has access to. + """ + url = URL(server.url) + if not url.scheme or not url.netloc: + raise HTTPException(status_code=400, detail=f"Invalid url {server.url}") + server.url = f"{url.scheme}://{url.netloc}" + if Server.get_using_url(db, server.url): + raise HTTPException( + status_code=400, detail=f"Server with url '{server.url}' already exists." 
+ ) + server = Server.create(db, server) + return server + + +@router.get( + "/api/v1/servers/{obj_id}", + response_model=schemas.ServerResponse, +) +def retrieve_server(obj_id: int, db=Depends(get_db)): + """ + Retrieve a specific server configuration + """ + server = Server.get(db=db, object_id=obj_id) + if not server: + raise HTTPException( + status_code=404, + detail=f"Server configuration with ID {obj_id} can not be found.", + ) + return server + + +@router.get("/api/v1/servers", response_model=List[schemas.ServerResponse]) +def list_servers(db=Depends(get_db)): + """ + List all servers configured to work with the application that users can authorize against. + """ + servers = Server.get_all(db) + return servers diff --git a/app/schemas.py b/app/schemas.py new file mode 100644 index 0000000..3484d1d --- /dev/null +++ b/app/schemas.py @@ -0,0 +1,157 @@ +# Schema Definitions +from enum import Enum + +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel +from app.common_tags import SYNC_FAILURES_METADATA, JOB_ID_METADATA + + +class FileStatusEnum(str, Enum): + queued = "Sync Queued" + syncing = "Syncing file" + latest_sync_failed = "Latest Sync Failed" + file_available = "File available" + file_unavailable = "File unavailable" + + +class ServerBase(BaseModel): + url: str + + +class ServerResponse(BaseModel): + id: int + url: str + + class Config: + orm_mode = True + + +class ServerCreate(ServerBase): + client_id: str + client_secret: str + + +class Server(ServerCreate): + id: int + + class Config: + orm_mode = True + + +class User(BaseModel): + username: str + refresh_token: str + server: int + + +class FileBase(BaseModel): + form_id: int + + +class FileListItem(BaseModel): + url: str + id: int + form_id: int + filename: str + file_status: FileStatusEnum = FileStatusEnum.file_unavailable.value + + +class FileCreate(FileBase): + user: int + filename: Optional[str] + is_active: bool = True + meta_data: dict = {SYNC_FAILURES_METADATA: 0, JOB_ID_METADATA: ""} + + +class ConfigurationResponse(BaseModel): + id: int + server_address: str + site_name: str + token_name: str + project_name: str + + class Config: + orm_mode = True + + +class ConfigurationListResponse(BaseModel): + url: Optional[str] + id: int + site_name: str + token_name: str + project_name: str + + class Config: + orm_mode = True + + +class ConfigurationCreateRequest(BaseModel): + server_address: str + site_name: str + token_name: str + project_name: str + token_value: str + + +class ConfigurationPatchRequest(BaseModel): + server_address: Optional[str] + site_name: Optional[str] + token_name: Optional[str] + project_name: Optional[str] + + +class ConfigurationCreate(ConfigurationCreateRequest): + user: int + + +class Configuration(ConfigurationCreate): + id: int + + class Config: + orm_mode = True + + +class File(FileCreate): + id: int + file_status: FileStatusEnum = FileStatusEnum.file_unavailable.value + last_updated: Optional[datetime] = None + last_synced: Optional[datetime] = None + configuration: Optional[Configuration] = None + + class Config: + orm_mode = True + + +class FileResponseBody(FileBase): + id: int + filename: str + file_status: FileStatusEnum = FileStatusEnum.file_unavailable.value + last_updated: Optional[datetime] = None + last_synced: Optional[datetime] = None + download_url: Optional[str] + download_url_valid_till: Optional[str] + configuration_url: Optional[str] + + class Config: + orm_mode = True + + +class FilePatchRequestBody(BaseModel): + configuration_id: int + 
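+# Example payload (illustrative values): FilePatchRequestBody validates the JSON body of a
+# PATCH to /api/v1/files/{file_id}, e.g. {"configuration_id": 3}, where 3 is the ID of an
+# existing Configuration that the Hyper file should be published with.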
+ +class UserBearerTokenResponse(BaseModel): + bearer_token: str + + +class FileRequestBody(FileBase): + sync_immediately: bool = False + configuration_id: Optional[int] + + +class EventResponse(BaseModel): + status: Optional[str] = "" + name: Optional[str] = "" + object_url: Optional[str] = "" diff --git a/app/settings.py b/app/settings.py new file mode 100644 index 0000000..3d451de --- /dev/null +++ b/app/settings.py @@ -0,0 +1,49 @@ +from pydantic import BaseSettings + + +class Settings(BaseSettings): + app_name: str = "Duva" + app_description: str = "" + app_version: str = "0.0.1" + app_host: str = "127.0.0.1" + app_port: int = 8000 + database_url: str = "sqlite:///./sqllite_db.db" + debug: bool = True + sentry_dsn: str = "" + session_same_site: str = "none" + # How long force update is locked after an update is completed + force_update_cooldown: int = 10 + # How long download URLs for HyperFiles should last + download_url_lifetime: int = 3600 + # Generate secret key using: + # dd if=/dev/urandom bs=32 count=1 2>/dev/null | openssl base64 + secret_key: str = "xLLwpyLgT0YumXu77iDYX+HDVBX6djFFVbAWPhhHAHY=" + enable_secure_sessions: bool = False + # check_same_thread: False is only needed for SQLite + # https://fastapi.tiangolo.com/tutorial/sql-databases/#note + db_connect_args: dict = {} + media_path: str = "/app/media" + s3_bucket: str = "hypermind-mvp" + s3_region: str = "eu-west-1" + # For more on tokens, head here: + # https://help.tableau.com/current/server/en-us/security_personal_access_tokens.htm + tableau_server_address: str = "" + tableau_site_name: str = "" + tableau_token_name: str = "" + tableau_token_value: str = "" + redis_url: str = "redis://cache" + redis_host: str = "cache" + redis_port: int = 6379 + redis_db: int = 0 + # CORS Configuration + cors_allowed_origins: list = ["http://localhost:3000", "http://localhost:8000"] + cors_allow_credentials: bool = True + cors_allowed_methods: list = ["*"] + cors_allowed_headers: list = ["*"] + cors_max_age: int = -1 + # HyperFile job settings + job_failure_limit: int = 5 + schedule_all_active: bool = False + + +settings = Settings() diff --git a/app/tests/conftest.py b/app/tests/conftest.py new file mode 100644 index 0000000..3e72571 --- /dev/null +++ b/app/tests/conftest.py @@ -0,0 +1,64 @@ +import fakeredis +import pytest + +from app.main import app +from app.tests.test_base import TestingSessionLocal +from app.utils.utils import get_db, get_redis_client + + +TEST_REDIS_SERVER = fakeredis.FakeServer() + + +def override_get_db(): + try: + db = TestingSessionLocal() + yield db + finally: + db.close() + + +def override_get_redis_client(): + redis_client = fakeredis.FakeRedis(server=TEST_REDIS_SERVER) + try: + yield redis_client + finally: + pass + + +app.dependency_overrides[get_db] = override_get_db +app.dependency_overrides[get_redis_client] = override_get_redis_client + + +@pytest.fixture(scope="function") +def create_user_and_login(): + from app import schemas + from app.models import User, Server + from app.utils.auth_utils import create_session + + db = TestingSessionLocal() + redis_client = fakeredis.FakeRedis(server=TEST_REDIS_SERVER) + server = Server.create( + db, + schemas.ServerCreate( + url="http://testserver", + client_id="some_client_id", + client_secret="some_secret_value", + ), + ) + if User.get_using_username(db, "bob"): + db.query(User).filter(User.username == "bob").delete() + db.commit() + + user = User.create( + db, + schemas.User(username="bob", refresh_token="somes3cr3tvalu3", server=server.id), + ) + + _, 
bearer_token = create_session(user, redis_client) + yield user, bearer_token + + # Clean up created objects + db.query(User).filter(User.id == user.id).delete() + db.query(Server).filter(Server.id == server.id).delete() + db.commit() + db.close() diff --git a/app/tests/routes/test_configuration.py b/app/tests/routes/test_configuration.py new file mode 100644 index 0000000..54e78ab --- /dev/null +++ b/app/tests/routes/test_configuration.py @@ -0,0 +1,101 @@ +from app.models import Configuration +from app import schemas +from app.tests.test_base import TestBase + + +class TestConfiguration(TestBase): + def _create_configuration(self, auth_credentials: dict, config_data: dict = None): + config_data = ( + config_data + or schemas.ConfigurationCreateRequest( + site_name="test", + server_address="http://test", + token_name="test", + token_value="test", + project_name="default", + ).dict() + ) + response = self.client.post( + "/api/v1/configurations", json=config_data, headers=auth_credentials + ) + + # Returns a 400 exception when configuration already exists + resp = self.client.post( + "/api/v1/configurations", json=config_data, headers=auth_credentials + ) + assert resp.status_code == 400 + assert resp.json() == {"detail": "Configuration already exists"} + return response + + def _cleanup_configs(self): + self.db.query(Configuration).delete() + self.db.commit() + + def test_create_retrieve_config(self, create_user_and_login): + _, jwt = create_user_and_login + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + response = self._create_configuration(auth_credentials) + + assert response.status_code == 201 + config_id = response.json().get("id") + expected_data = schemas.ConfigurationResponse( + site_name="test", + server_address="http://test", + token_name="test", + project_name="default", + id=config_id, + ).dict() + assert response.json() == expected_data + + # Able to retrieve Tableau Configuration + response = self.client.get( + f"/api/v1/configurations/{config_id}", headers=auth_credentials + ) + assert response.status_code == 200 + assert response.json() == expected_data + self._cleanup_configs() + + def test_delete_config(self, create_user_and_login): + _, jwt = create_user_and_login + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + response = self._create_configuration(auth_credentials) + assert response.status_code == 201 + config_id = response.json().get("id") + current_count = len(Configuration.get_all(self.db)) + + response = self.client.delete( + f"/api/v1/configurations/{config_id}", headers=auth_credentials + ) + assert response.status_code == 204 + assert len(Configuration.get_all(self.db)) == current_count - 1 + + def test_patch_config(self, create_user_and_login): + _, jwt = create_user_and_login + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + response = self._create_configuration(auth_credentials) + + assert response.status_code == 201 + config_id = response.json().get("id") + data = schemas.ConfigurationPatchRequest( + site_name="test_change", + ).dict() + expected_data = schemas.ConfigurationResponse( + site_name="test_change", + server_address="http://test", + token_name="test", + project_name="default", + id=config_id, + ).dict() + + # Able to patch Tableau Configuration + response = self.client.patch( + f"/api/v1/configurations/{config_id}", + json=data, + headers=auth_credentials, + ) + assert response.status_code == 200 + assert response.json() == expected_data + 
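+ # Drop the Configuration rows created above so later tests start from a clean table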
self._cleanup_configs() diff --git a/app/tests/routes/test_file.py b/app/tests/routes/test_file.py new file mode 100644 index 0000000..781f69a --- /dev/null +++ b/app/tests/routes/test_file.py @@ -0,0 +1,344 @@ +from unittest.mock import patch + +from httpx._models import Response + +from app import schemas +from app.models import HyperFile, Configuration +from app.tests.test_base import TestBase + + +class TestFileRoute(TestBase): + @patch("app.routers.file.S3Client.generate_presigned_download_url") + @patch("app.routers.file.schedule_hyper_file_cron_job") + @patch("app.utils.onadata_utils.httpx.get") + @patch("app.utils.onadata_utils.get_access_token") + def _create_file( + self, + auth_credentials, + mock_access_token, + mock_get, + mock_schedule_form, + mock_presigned_create, + file_data: dict = None, + ): + mock_presigned_create.return_value = "https://testing.s3.amazonaws.com/1/bob/check_fields.hyper?AWSAccessKeyId=key&Signature=sig&Expires=1609838540" + mock_access_token.return_value = "some_access_token" + mock_schedule_form.return_value = True + mock_get.return_value = Response( + json={ + "url": "https://testserver/api/v1/forms/1", + "formid": 1, + "metadata": [], + "owner": "https://testserver/api/v1/users/bob", + "created_by": "https://testserver/api/v1/users/bob", + "public": True, + "public_data": True, + "public_key": "", + "require_auth": False, + "submission_count_for_today": 0, + "tags": [], + "title": "check_fields", + "users": [ + { + "is_org": False, + "metadata": {}, + "first_name": "Bob", + "last_name": "", + "user": "bob", + "role": "owner", + } + ], + "enketo_url": "https://enketo-stage.ona.io/x/Z7k6kqn9", + "enketo_preview_url": "https://enketo-stage.ona.io/preview/3eZVdQ26", + "enketo_single_submit_url": "https://enketo-stage.ona.io/x/Z7k6kqn9", + "num_of_submissions": 3, + "last_submission_time": "2020-11-16T15:38:28.779972+00:00", + "form_versions": [], + "data_views": [], + "description": "", + "downloadable": True, + "allows_sms": False, + "encrypted": False, + "sms_id_string": "check_fields", + "id_string": "check_fields", + "date_created": "2019-11-21T08:00:06.668073-05:00", + "date_modified": "2020-11-16T10:38:28.744440-05:00", + "uuid": "da3eed4893e74723b555f3255c432ae4", + "bamboo_dataset": "", + "instances_with_geopoints": True, + "instances_with_osm": False, + "version": "201911211300", + "has_hxl_support": False, + "last_updated_at": "2020-11-16T10:38:28.744455-05:00", + "hash": "md5:692e501a01879439dcec79399484de4f", + "is_merged_dataset": False, + "project": "https://testserver/api/v1/projects/500", + }, + status_code=200, + ) + + file_data = ( + file_data + or schemas.FileRequestBody( + server_url="http://testserver", form_id=1, immediate_sync=False + ).dict() + ) + response = self.client.post( + "/api/v1/files", json=file_data, headers=auth_credentials + ) + return response + + def _cleanup_files(self): + self.db.query(HyperFile).delete() + self.db.commit() + + def test_file_create(self, create_user_and_login): + _, jwt = create_user_and_login + num_of_files = len(HyperFile.get_all(self.db)) + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + response = self._create_file(auth_credentials) + + assert response.status_code == 201 + assert len(HyperFile.get_all(self.db)) == num_of_files + 1 + self._cleanup_files() + + @patch("app.routers.file.S3Client.generate_presigned_download_url") + def test_file_update(self, mock_presigned_create, create_user_and_login): + mock_presigned_create.return_value = 
"https://testing.s3.amazonaws.com/1/bob/check_fields.hyper?AWSAccessKeyId=key&Signature=sig&Expires=1609838540" + user, jwt = create_user_and_login + num_of_files = len(HyperFile.get_all(self.db)) + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + response = self._create_file(auth_credentials) + + assert response.status_code == 201 + assert len(HyperFile.get_all(self.db)) == num_of_files + 1 + + file_id = response.json().get("id") + # Test fails with 400 if update with non-existant + # configuration + data = {"configuration_id": 10230} + response = self.client.patch( + f"/api/v1/files/{file_id}", json=data, headers=auth_credentials + ) + + assert response.status_code == 400 + assert response.json() == { + "detail": "Tableau configuration with ID 10230 not found" + } + + # Correctly updates tableau configuration + configuration = Configuration.create( + self.db, + schemas.ConfigurationCreate( + user=user.id, + server_address="http://testserver", + site_name="test", + token_name="test", + token_value="test", + project_name="test", + ), + ) + data = {"configuration_id": configuration.id} + response = self.client.patch( + f"/api/v1/files/{file_id}", json=data, headers=auth_credentials + ) + + assert response.status_code == 200 + assert ( + response.json().get("configuration_url") + == f"http://testserver/api/v1/configurations/{configuration.id}" + ) + self._cleanup_files() + + @patch("app.routers.file.S3Client.delete") + def test_file_delete(self, mock_s3_delete, create_user_and_login): + mock_s3_delete.return_value = True + _, jwt = create_user_and_login + num_of_files = len(HyperFile.get_all(self.db)) + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + response = self._create_file(auth_credentials) + + assert response.status_code == 201 + assert len(HyperFile.get_all(self.db)) == num_of_files + 1 + num_of_files += 1 + file_id = response.json().get("id") + + response = self.client.delete( + f"/api/v1/files/{file_id}", headers=auth_credentials + ) + assert response.status_code == 204 + assert len(HyperFile.get_all(self.db)) == num_of_files - 1 + self._cleanup_files() + + @patch("app.routers.file.S3Client.generate_presigned_download_url") + def test_file_with_config(self, mock_presigned_create, create_user_and_login): + mock_presigned_create.return_value = "https://testing.s3.amazonaws.com/1/bob/check_fields.hyper?AWSAccessKeyId=key&Signature=sig&Expires=1609838540" + user, jwt = create_user_and_login + num_of_files = len(HyperFile.get_all(self.db)) + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + config = Configuration.create( + self.db, + schemas.ConfigurationCreate( + user=user.id, + server_address="http://test", + site_name="test", + project_name="test", + token_name="test", + token_value="test", + ), + ) + file_data = schemas.FileRequestBody( + server_url="http://testserver", + form_id=1, + immediate_sync=False, + configuration_id=config.id, + ).dict() + response = self._create_file(auth_credentials, file_data=file_data) + response_json = response.json() + response_json.pop("download_url_valid_till") + expected_data = { + "download_url": "https://testing.s3.amazonaws.com/1/bob/check_fields.hyper?AWSAccessKeyId=key&Signature=sig&Expires=1609838540", + "filename": "check_fields.hyper", + "file_status": schemas.FileStatusEnum.queued.value, + "form_id": 1, + "id": 1, + "last_synced": None, + "last_updated": None, + "configuration_url": 
f"http://testserver/api/v1/configurations/{config.id}", + } + + assert response.status_code == 201 + assert response_json == expected_data + assert len(HyperFile.get_all(self.db)) == num_of_files + 1 + + # Able to change Tableau Server Config + config_2 = Configuration.create( + self.db, + schemas.ConfigurationCreate( + user=user.id, + server_address="http://tes2", + site_name="tes2t", + project_name="test2", + token_name="test2", + token_value="test2", + ), + ) + file_data = schemas.FilePatchRequestBody(configuration_id=config_2.id).dict() + response = self.client.patch( + "/api/v1/files/1", json=file_data, headers=auth_credentials + ) + assert response.status_code == 200 + response_json = response.json() + assert ( + response_json.get("configuration_url") + == f"http://testserver/api/v1/configurations/{config_2.id}" + ) + # Delete Tableau Configurations + self.db.query(Configuration).delete() + self.db.commit() + self._cleanup_files() + + def test_file_list(self, create_user_and_login): + user, jwt = create_user_and_login + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + self._create_file(auth_credentials) + + response = self.client.get("/api/v1/files", headers=auth_credentials) + assert response.status_code == 200 + assert len(user.files) == 1 + assert len(response.json()) == len(user.files) + hyperfile = user.files[0] + expected_data = schemas.FileListItem( + url=f"http://testserver/api/v1/files/{hyperfile.id}", + id=hyperfile.id, + form_id=hyperfile.form_id, + filename=hyperfile.filename, + ).dict() + expected_data.update({"file_status": schemas.FileStatusEnum.queued.value}) + assert response.json()[0] == expected_data + + # Test filtering + response = self.client.get( + "/api/v1/files?form_id=000", headers=auth_credentials + ) + assert response.status_code == 200 + assert len(response.json()) == 0 + + response = self.client.get("/api/v1/files?form_id=1", headers=auth_credentials) + assert response.status_code == 200 + assert len(response.json()) == len(user.files) + + self._cleanup_files() + + @patch("app.routers.file.start_csv_import_to_hyper") + def test_trigger_hyper_file_sync( + self, mock_start_csv_import, create_user_and_login + ): + _, jwt = create_user_and_login + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + response = self._create_file(auth_credentials) + + # User is able to trigger a force update + file_id = response.json().get("id") + + with patch("app.utils.utils.redis.Redis"): + response = self.client.post( + f"/api/v1/files/{file_id}/sync", headers=auth_credentials + ) + + assert response.status_code == 202 + expected_json = response.json() + update_count = mock_start_csv_import.call_count + + # Returns a 202 status_code when update is on-going + # and doesn't trigger another update + response = self.client.post( + f"/api/v1/files/{file_id}/sync", headers=auth_credentials + ) + assert response.status_code == 202 + assert update_count == mock_start_csv_import.call_count + assert response.json() == expected_json + self._cleanup_files() + + @patch("app.routers.file.S3Client.generate_presigned_download_url") + def test_file_get(self, mock_presigned_create, create_user_and_login): + mock_presigned_create.return_value = "https://testing.s3.amazonaws.com/1/bob/check_fields.hyper?AWSAccessKeyId=key&Signature=sig&Expires=1609838540" + user, jwt = create_user_and_login + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + self._create_file(auth_credentials) + + hyperfile = 
user.files[0] + response = self.client.get( + f"/api/v1/files/{hyperfile.id}", headers=auth_credentials + ) + expected_keys = [ + "form_id", + "id", + "filename", + "file_status", + "last_updated", + "last_synced", + "download_url", + "download_url_valid_till", + "configuration_url", + ] + assert response.status_code == 200 + assert list(response.json().keys()) == expected_keys + assert response.json()["id"] == hyperfile.id + self._cleanup_files() + + def test_file_get_raises_error_on_invalid_id(self, create_user_and_login): + user, jwt = create_user_and_login + jwt = jwt.decode("utf-8") + auth_credentials = {"Authorization": f"Bearer {jwt}"} + + response = self.client.get("/api/v1/files/form_id=1", headers=auth_credentials) + assert response.status_code == 400 + assert response.json() == {"detail": "Invalid file ID"} diff --git a/app/tests/routes/test_oauth.py b/app/tests/routes/test_oauth.py new file mode 100644 index 0000000..da48935 --- /dev/null +++ b/app/tests/routes/test_oauth.py @@ -0,0 +1,134 @@ +from unittest.mock import patch + +from httpx._models import Response + +from app import schemas +from app.models import Server, User +from app.tests.test_base import TestBase + + +class TestOAuthRoute(TestBase): + def setup_class(cls): + super().setup_class() + cls.mock_server = Server.create( + cls.db, + schemas.ServerCreate( + url="http://testserver", + client_id="some_client_id", + client_secret="some_client_secret", + ), + ) + cls.mock_server_2 = Server.create( + cls.db, + schemas.ServerCreate( + url="http://dupli.testserver", + client_id="some_client_id", + client_secret="some_client_secret", + ), + ) + + def teardown_class(cls): + cls.db.query(Server).filter(Server.id == cls.mock_server.id).delete() + cls.db.query(Server).filter(Server.id == cls.mock_server_2.id).delete() + cls.db.commit() + super().teardown_class() + + @patch("app.routers.oauth.uuid.uuid4") + def test_oauth_login_redirects(self, mock_uuid): + """ + Test that the "oauth/login" route redirects + to the correct URL + """ + # Ensure a 400 is raised when a user tries + # to login to a server that isn't configured + response = self.client.get("/api/v1/oauth/login?server_url=http://testserve") + assert response.status_code == 400 + assert response.json() == {"detail": "Server not configured"} + + mock_uuid.return_value = "some_uuid" + response = self.client.get("/api/v1/oauth/login?server_url=http://testserver") + assert ( + response.url + == f"http://testserver/o/authorize?client_id={self.mock_server.client_id}&response_type=code&state=some_uuid" + ) + + @patch("app.routers.oauth.httpx") + @patch("app.routers.oauth.redis.Redis.get") + def test_oauth_callback(self, mock_redis_get, mock_httpx): + """ + Test that the OAuth2 callback URL confirms the + auth state and creates a User object + """ + mock_redis_get.return_value = None + url = "/api/v1/oauth/callback?code=some_code&state=some_uuid" + response = self.client.get(url) + assert response.status_code == 401 + assert response.json() == { + "detail": "Authorization state can not be confirmed." 
+ } + + assert len(User.get_all(self.db)) == 0 + mock_auth_state = f'{{"server_id": {self.mock_server.id}}}' + mock_redis_get.return_value = mock_auth_state + mock_httpx.post.return_value = Response( + json={ + "access_token": "some_access_token", + "token_type": "Bearer", + "expires_in": 36000, + "refresh_token": "some_refresh_token", + "scope": "read write groups", + }, + status_code=200, + ) + mock_httpx.get.return_value = Response( + json={ + "api_token": "some_api_token", + "temp_token": "some_temp_token", + "city": "Nairobi", + "country": "Kenya", + "gravatar": "avatar.png", + "name": "Bob", + "email": "bob@user.com", + "organization": "", + "require_auth": False, + "twitter": "", + "url": "http://testserver/api/v1/profiles/bob", + "user": "http://testserver/api/v1/users/bob", + "username": "bob", + "website": "", + }, + status_code=200, + ) + self.redis_client.set("some_uuid", mock_auth_state) + response = self.client.get(url) + assert response.status_code == 200 + assert len(User.get_all(self.db)) == 1 + assert "bearer_token" in response.json().keys() + + # Create user from different server with same username + self.client.cookies.clear() + mock_auth_state = f'{{"server_id": {self.mock_server_2.id}}}' + mock_redis_get.return_value = mock_auth_state + mock_httpx.get.return_value = Response( + json={ + "api_token": "some_api_token", + "temp_token": "some_temp_token", + "city": "Nairobi", + "country": "Kenya", + "gravatar": "avatar.png", + "name": "Bob", + "email": "bob@user.com", + "organization": "", + "require_auth": False, + "twitter": "", + "url": "http://dupli.testserver/api/v1/profiles/bob", + "user": "http://dupli.testserver/api/v1/users/bob", + "username": "bob", + "website": "", + }, + status_code=200, + ) + self.redis_client.set("some_uuid", mock_auth_state) + response = self.client.get(url) + assert response.status_code == 200 + assert len(User.get_all(self.db)) == 2 diff --git a/app/tests/routes/test_server.py b/app/tests/routes/test_server.py new file mode 100644 index 0000000..ff5b082 --- /dev/null +++ b/app/tests/routes/test_server.py @@ -0,0 +1,55 @@ +from typing import Tuple +from fastapi.responses import Response +from app import schemas +from app.models import Server +from app.tests.test_base import TestBase + + +class TestServerRoute(TestBase): + def _create_server(self, url: str = "http://testserver") -> Tuple[Response, int]: + initial_count = len(Server.get_all(self.db)) + data = schemas.ServerCreate( + url=url, + client_id="some_client_id", + client_secret="some_client_secret", + ).dict() + response = self.client.post("/api/v1/servers", json=data) + return response, initial_count + + def _cleanup_server(self): + self.db.query(Server).filter(Server.url == "http://testserver").delete() + self.db.commit() + + def test_bad_url_rejected(self): + url = "bad_url" + response, _ = self._create_server(url=url) + assert response.status_code == 400 + assert response.json() == {"detail": f"Invalid url {url}"} + + def test_create_server(self): + response, initial_count = self._create_server() + expected_keys = ["id", "url"] + assert response.status_code == 201 + assert expected_keys == list(response.json().keys()) + assert len(Server.get_all(self.db)) == initial_count + 1 + + # Test trying to create a different server with the same URL + # returns a 400 response + response, _ = self._create_server() + assert response.status_code == 400 + assert response.json() == { + "detail": "Server with url 'http://testserver' already exists." 
+ } + self._cleanup_server() + + def test_retrieve_server(self): + """ + Test the retrieve server configuration routes + """ + response, _ = self._create_server() + expected_response = response.json() + server_id = expected_response.get("id") + response = self.client.get(f"/api/v1/servers/{server_id}") + assert response.status_code == 200 + assert response.json() == expected_response + self._cleanup_server() diff --git a/app/tests/test_base.py b/app/tests/test_base.py new file mode 100644 index 0000000..83b68d4 --- /dev/null +++ b/app/tests/test_base.py @@ -0,0 +1,35 @@ +import os + +import fakeredis +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm.session import sessionmaker + +from app.main import app +from app.database import Base + + +SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" + +# Delete existing test database +if os.path.exists("./test.db"): + os.remove("./test.db") + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base.metadata.create_all(bind=engine) + + +class TestBase: + @classmethod + def setup_class(cls): + cls.client = TestClient(app=app) + cls.db = TestingSessionLocal() + cls.redis_client = fakeredis.FakeRedis() + + @classmethod + def teardown_class(cls): + cls.db.close() diff --git a/app/tests/test_main.py b/app/tests/test_main.py new file mode 100644 index 0000000..7aea20d --- /dev/null +++ b/app/tests/test_main.py @@ -0,0 +1,14 @@ +from app.tests.test_base import TestBase +from app.settings import settings + + +class TestMain(TestBase): + def test_home_route(self): + response = self.client.get("/") + assert response.status_code == 200 + assert response.json() == { + "app_name": settings.app_name, + "app_description": settings.app_description, + "app_version": settings.app_version, + "docs_url": "http://testserver/docs", + } diff --git a/app/tests/utils/test_hyper_utils.py b/app/tests/utils/test_hyper_utils.py new file mode 100644 index 0000000..265ebe7 --- /dev/null +++ b/app/tests/utils/test_hyper_utils.py @@ -0,0 +1,117 @@ +""" +TEsts for the hyper_utils module +""" +from unittest.mock import patch, MagicMock +from sqlalchemy.orm.attributes import flag_modified + +from app.common_tags import JOB_ID_METADATA, SYNC_FAILURES_METADATA +from app.models import HyperFile +from app.schemas import FileCreate, FileStatusEnum +from app.settings import settings +from app.tests.test_base import TestBase +from app.utils.hyper_utils import ( + schedule_hyper_file_cron_job, + cancel_hyper_file_job, + handle_hyper_file_job_completion, +) + + +class TestHyperUtils(TestBase): + @patch("app.utils.hyper_utils.schedule_cron_job") + def _schedule_hyper_file_cron_job( + self, mock_schedule_cron_job, user, job_mock: MagicMock = MagicMock + ): + def dummy_func(a: str): + print(a) + + mock_schedule_cron_job.side_effect = job_mock + + hyperfile = HyperFile.create( + self.db, + FileCreate( + user=user.id, filename="test.hyper", is_active=True, form_id="111" + ), + ) + + schedule_hyper_file_cron_job(dummy_func, hyperfile.id, db=self.db) + self.db.refresh(hyperfile) + + assert mock_schedule_cron_job.called is True + return hyperfile + + def test_schedule_hyper_file_cron_job(self, create_user_and_login): + user, _ = create_user_and_login + job_mock = MagicMock + job_mock.id = "some_id" + hyperfile = self._schedule_hyper_file_cron_job(user=user, job_mock=job_mock) + expected_metadata = {JOB_ID_METADATA: 
job_mock.id, SYNC_FAILURES_METADATA: 0} + # Ensure that the HyperFiles' metadata is updated accordingly + assert hyperfile.meta_data == expected_metadata + # Clean up created hyper file + self.db.query(HyperFile).delete() + self.db.commit() + + @patch("app.utils.hyper_utils.cancel_job") + def test_cancel_hyper_file_job(self, mock_cancel_job, create_user_and_login): + user, _ = create_user_and_login + job_mock = MagicMock + job_mock.id = "some_id" + hyperfile = self._schedule_hyper_file_cron_job(user=user, job_mock=job_mock) + self.db.refresh(hyperfile) + hyperfile.meta_data[SYNC_FAILURES_METADATA] = 4 + flag_modified(hyperfile, "meta_data") + self.db.commit() + self.db.refresh(hyperfile) + + assert hyperfile.meta_data == { + JOB_ID_METADATA: job_mock.id, + SYNC_FAILURES_METADATA: 4, + } + # Ensure that cancelling a hyper file job updates it's metadata + cancel_hyper_file_job(hyperfile.id, job_mock.id, db=self.db) + self.db.refresh(hyperfile) + expected_metadata = {JOB_ID_METADATA: "", SYNC_FAILURES_METADATA: 0} + assert hyperfile.meta_data == expected_metadata + assert mock_cancel_job.called is True + self.db.query(HyperFile).delete() + self.db.commit() + + @patch("app.utils.hyper_utils.cancel_job") + def test_handle_hyper_file_job_completion( + self, mock_cancel_job, create_user_and_login + ): + user, _ = create_user_and_login + job_mock = MagicMock + job_mock.id = "some_id" + hyperfile = self._schedule_hyper_file_cron_job(user=user, job_mock=job_mock) + self.db.refresh(hyperfile) + failure_count = hyperfile.meta_data[SYNC_FAILURES_METADATA] + + # Test that the failure count is updated on job failure + handle_hyper_file_job_completion( + hyperfile.id, + self.db, + job_succeeded=False, + file_status=FileStatusEnum.latest_sync_failed, + ) + self.db.refresh(hyperfile) + assert hyperfile.meta_data[SYNC_FAILURES_METADATA] == failure_count + 1 + assert hyperfile.file_status == FileStatusEnum.latest_sync_failed + assert mock_cancel_job.called is False + + # Test job is cancelled once job failure limit is reached + hyperfile.meta_data[SYNC_FAILURES_METADATA] = settings.job_failure_limit + flag_modified(hyperfile, "meta_data") + self.db.commit() + + handle_hyper_file_job_completion( + hyperfile.id, + self.db, + job_succeeded=False, + file_status=FileStatusEnum.latest_sync_failed, + ) + self.db.refresh(hyperfile) + assert hyperfile.meta_data == {JOB_ID_METADATA: "", SYNC_FAILURES_METADATA: 0} + assert mock_cancel_job.called is True + self.db.query(HyperFile).delete() + self.db.commit() diff --git a/app/tests/utils/test_onadata_utils.py b/app/tests/utils/test_onadata_utils.py new file mode 100644 index 0000000..11a310f --- /dev/null +++ b/app/tests/utils/test_onadata_utils.py @@ -0,0 +1,45 @@ +""" +Tests for the onadata_utils module +""" +from unittest.mock import patch + +from httpx._models import Response + +from app.models import Server, User +from app.tests.test_base import TestBase +from app.utils.onadata_utils import get_access_token + + +class TestOnadataUtils(TestBase): + @patch("app.utils.onadata_utils.httpx.post") + def test_get_access_token(self, mock_httpx_post, create_user_and_login): + """ + Test the get_access_token function correctly retrieves the + access_token and resets the refresh_token + """ + user, jwt = create_user_and_login + mock_httpx_post.return_value = Response( + json={ + "refresh_token": "new_token", + "access_token": "new_access_token", + "expiresIn": "somedate", + }, + status_code=200, + ) + old_refresh_token = user.refresh_token + server = Server.get(self.db, 
user.server) + ret = get_access_token(user, server, self.db) + + assert ret == "new_access_token" + mock_httpx_post.assert_called_with( + f"{server.url}/o/token/", + data={ + "grant_type": "refresh_token", + "refresh_token": User.decrypt_value(user.refresh_token), + "client_id": server.client_id, + }, + auth=(server.client_id, Server.decrypt_value(server.client_secret)), + ) + user = User.get(self.db, user.id) + assert user.refresh_token != old_refresh_token + assert User.decrypt_value(user.refresh_token) == "new_token" diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/utils/auth_utils.py b/app/utils/auth_utils.py new file mode 100644 index 0000000..f1a1ed7 --- /dev/null +++ b/app/utils/auth_utils.py @@ -0,0 +1,106 @@ +# Authentication/Authorization Utilities +import uuid +from datetime import datetime, timedelta +from typing import Tuple + +import redis +import jwt +from fastapi import Request, Depends +from fastapi.exceptions import HTTPException + +from app import schemas +from app.settings import settings +from app.models import User +from app.utils.utils import get_db, get_redis_client + + +def create_session( + user: schemas.User, redis_client: redis.Redis, request: Request = None +) -> Tuple[Request, str]: + session_key = f"{user.username}-sessions" + session_id = str(uuid.uuid4()) + expiry_time = datetime.now() + timedelta(days=14) + expiry_timestamp = datetime.timestamp(expiry_time) + stored_session = session_id + f":{expiry_timestamp}" + redis_client.sadd(session_key, stored_session) + jwt_data = { + "username": user.username, + "session-id": session_id, + "server_id": user.server, + } + encoded_jwt = jwt.encode(jwt_data, settings.secret_key, algorithm="HS256") + + if request: + request.session["session-data"] = jwt_data + return request, encoded_jwt + + +class IsAuthenticatedUser: + def __init__(self, raise_errors: bool = True) -> None: + self.raise_errors = raise_errors + + def __call__( + self, request: Request, db=Depends(get_db), redis=Depends(get_redis_client) + ): + def _raise_error(exception: Exception): + if self.raise_errors: + if request.session: + request.session.clear() + raise exception + return None + + self.db = db + self.redis_client = redis + + session_data = request.session.get("session-data") + invalid_credentials_error = HTTPException( + status_code=401, detail="Invalid authentication credentials" + ) + if not session_data: + auth = request.headers.get("authorization") + if auth: + auth_type, value = auth.split(" ") + + if auth_type != "Bearer": + return _raise_error(invalid_credentials_error) + + try: + session_data = jwt.decode( + value, settings.secret_key, algorithms=["HS256"] + ) + except jwt.DecodeError: + return _raise_error(invalid_credentials_error) + + if session_data: + session_id = session_data.get("session-id") + username = session_data.get("username") + server_id = session_data.get("server_id") + session_key = f"{username}-sessions" + + if self.is_valid_session(session_id=session_id, session_key=session_key): + user = User.get_using_server_and_username(self.db, username, server_id) + if not user: + return _raise_error(invalid_credentials_error) + return user + + return _raise_error( + HTTPException(status_code=401, detail="Authentication required") + ) + + def is_valid_session(self, session_key: str, session_id: str) -> bool: + sessions = self.redis_client.smembers(session_key) + for session in sessions: + sess_id, expiry = session.decode("utf-8").split(":") + + try: + expiry 
= int(expiry) + except ValueError: + expiry = float(expiry) + + if expiry and datetime.fromtimestamp(expiry) > datetime.now(): + if sess_id == session_id: + return True + else: + self.redis_client.srem(session_key, session) + + return False diff --git a/app/utils/hyper_utils.py b/app/utils/hyper_utils.py new file mode 100644 index 0000000..ad22b51 --- /dev/null +++ b/app/utils/hyper_utils.py @@ -0,0 +1,251 @@ +""" +File containing utility functions related to a Hyper Database / HyperFile +""" +from typing import List, Callable +from pathlib import Path + +import pandas as pd +from datetime import datetime +from pandas.errors import EmptyDataError +from rq.job import Job +from sqlalchemy.orm.session import Session +from sqlalchemy.orm.attributes import flag_modified +from tableauhyperapi import ( + SqlType, + Connection, + HyperProcess, + TableName, + escape_string_literal, + TableDefinition, + CreateMode, +) + +from app.database import SessionLocal +from app.schemas import FileStatusEnum +from app.settings import settings +from app.common_tags import JOB_ID_METADATA, SYNC_FAILURES_METADATA +from app.jobs.scheduler import schedule_cron_job, cancel_job +from app.libs.s3.client import S3Client +from app.libs.tableau.client import TableauClient +from app.models import HyperFile, Configuration + + +def element_type_to_hyper_sql_type(elem_type: str) -> SqlType: + type_map = { + "integer": SqlType.big_int, + "decimal": SqlType.double, + "text": SqlType.text, + } + return type_map.get(elem_type) + + +def _pandas_type_to_hyper_sql_type(_type: str) -> SqlType: + # Only supports text and numeric fields, more may be added later + type_map = { # noqa + "b": SqlType.text, + "i": SqlType.big_int, + "u": SqlType.text, + "f": SqlType.double, + "c": SqlType.text, + "O": SqlType.text, + "S": SqlType.text, + "a": SqlType.text, + "U": SqlType.text, + } + return type_map.get(_type) + + +def _import_csv_to_hyperfile( + path: str, + csv_path: str, + process: HyperProcess, + table_name: TableName = TableName("Extract", "Extract"), + null_field: str = "NULL", + delimiter: str = ",", +) -> int: + """ + Imports CSV data into a HyperFile + """ + with Connection(endpoint=process.endpoint, database=path) as connection: + command = ( + f"COPY {table_name} from {escape_string_literal(csv_path)} with " + f"(format csv, NULL '{null_field}', delimiter '{delimiter}', header)" + ) + count = connection.execute_command(command=command) + return count + + +def _prep_csv_for_import(csv_path: Path) -> List[TableDefinition.Column]: + """ + Creates a schema definition from an Onadata CSV Export + DISCLAIMER: This function doesn't actually try to derive the columns + type. It returns every column as a string column + """ + columns: List[SqlType] = [] + df = pd.read_csv(csv_path, na_values=["n/a", ""]) + df = df.convert_dtypes() + for name, dtype in df.dtypes.iteritems(): + column = TableDefinition.Column( + name, _pandas_type_to_hyper_sql_type(dtype.kind)() + ) + columns.append(column) + # Save dataframe to CSV as the dataframe is more cleaner + # in most cases. 
We also don't want the headers to be within + # the CSV as Hyper picks the header as a value + with open(csv_path, "w") as f: + f.truncate(0) + df.to_csv(csv_path, na_rep="NULL", header=True, index=False) + return columns + + +def handle_csv_import_to_hyperfile( + hyperfile: HyperFile, csv_path: str, process: HyperProcess, db: Session +) -> int: + file_path = hyperfile.retrieve_latest_file(db) + s3_destination = hyperfile.get_file_path(db) + configuration = hyperfile.configuration + + return handle_csv_import( + file_path=file_path, + csv_path=csv_path, + process=process, + configuration=configuration, + s3_destination=s3_destination, + ) + + +def handle_csv_import( + file_path: str, + csv_path: Path, + process: HyperProcess, + configuration: Configuration = None, + s3_destination: str = None, +) -> int: + """ + Handles CSV Import to Hyperfile + """ + table_name = TableName("Extract", "Extract") + try: + columns = _prep_csv_for_import(csv_path=csv_path) + except EmptyDataError: + return 0 + else: + with Connection( + endpoint=process.endpoint, + database=file_path, + create_mode=CreateMode.CREATE_AND_REPLACE, + ) as connection: + connection.catalog.create_schema("Extract") + extract_table = TableDefinition(table_name, columns=columns) + connection.catalog.create_table(extract_table) + + import_count = _import_csv_to_hyperfile( + path=file_path, + csv_path=str(csv_path), + table_name=table_name, + process=process, + ) + + # Store hyper file in S3 Storage + s3_client = S3Client() + s3_client.upload(file_path, s3_destination or Path(file_path).name) + + if configuration: + tableau_client = TableauClient(configuration=configuration) + tableau_client.publish_hyper(file_path) + + return import_count + + +def schedule_hyper_file_cron_job( + job_func: Callable, + hyperfile_id: int, + extra_job_args: list = [], + job_id_meta_tag: str = JOB_ID_METADATA, + job_failure_counter_meta_tag: str = SYNC_FAILURES_METADATA, + db: SessionLocal = SessionLocal(), +) -> Job: + """ + Schedules a Job that should run on a cron schedule for a particular + Hyperfile + """ + hf: HyperFile = HyperFile.get(db, hyperfile_id) + metadata = hf.meta_data or {} + + job: Job = schedule_cron_job(job_func, [hyperfile_id] + extra_job_args) + # Set meta tags to help track the started CRON Job + metadata[job_id_meta_tag] = job.id + metadata[job_failure_counter_meta_tag] = 0 + + hf.meta_data = metadata + flag_modified(hf, "meta_data") + db.commit() + return job + + +def cancel_hyper_file_job( + hyperfile_id: int, + job_id: str, + db: SessionLocal = SessionLocal(), + job_name: str = "app.utils.onadata_utils.start_csv_import_to_hyper_job", + job_id_meta_tag: str = JOB_ID_METADATA, + job_failure_counter_meta_tag: str = SYNC_FAILURES_METADATA, +) -> None: + """ + Cancels a scheduler Job related to a Hyper file and resets the job failure + counter and meta tag + """ + hf: HyperFile = HyperFile.get(db, hyperfile_id) + metadata = hf.meta_data or {} + + cancel_job(job_id, [hyperfile_id], job_name) + if metadata.get(job_id_meta_tag): + metadata[job_id_meta_tag] = "" + metadata[job_failure_counter_meta_tag] = 0 + hf.meta_data = metadata + flag_modified(hf, "meta_data") + db.commit() + + +def handle_hyper_file_job_completion( + hyperfile_id: int, + db: SessionLocal = SessionLocal(), + job_succeeded: bool = True, + object_updated: bool = True, + file_status: str = FileStatusEnum.file_available.value, + job_id_meta_tag: str = JOB_ID_METADATA, + job_failure_counter_meta_tag: str = SYNC_FAILURES_METADATA, +): + """ + Handles updating a HyperFile 
according to the outcome of a running Job; Updates + file status & tracks the jobs current failure counter. + """ + hf: HyperFile = HyperFile.get(db, hyperfile_id) + metadata = hf.meta_data or {} + + if job_succeeded: + if object_updated: + hf.last_updated = datetime.now() + metadata[job_failure_counter_meta_tag] = 0 + else: + failure_count = metadata.get(job_failure_counter_meta_tag) + if isinstance(failure_count, int): + metadata[job_failure_counter_meta_tag] = failure_count + 1 + else: + metadata[job_failure_counter_meta_tag] = failure_count = 0 + + if failure_count >= settings.job_failure_limit and hf.is_active: + cancel_hyper_file_job( + hyperfile_id, + metadata.get(job_id_meta_tag), + db=db, + job_id_meta_tag=job_id_meta_tag, + job_failure_counter_meta_tag=job_failure_counter_meta_tag, + ) + db.refresh(hf) + hf.is_active = False + + hf.meta_data = metadata + hf.file_status = file_status + flag_modified(hf, "meta_data") + db.commit() diff --git a/app/utils/onadata_utils.py b/app/utils/onadata_utils.py new file mode 100644 index 0000000..b55aed1 --- /dev/null +++ b/app/utils/onadata_utils.py @@ -0,0 +1,250 @@ +# Utility functions for Ona Data Aggregate Servers +import time +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Optional + +import httpx +import sentry_sdk +from fastapi_cache import caches +from sqlalchemy.orm.session import Session +from tableauhyperapi import HyperProcess, Telemetry + +from app import schemas +from app.common_tags import ( + ONADATA_TOKEN_ENDPOINT, + ONADATA_FORMS_ENDPOINT, + ONADATA_USER_ENDPOINT, + JOB_ID_METADATA, + HYPER_PROCESS_CACHE_KEY, +) +from app.database import SessionLocal +from app.models import HyperFile, Server, User +from app.settings import settings +from app.utils.hyper_utils import ( + handle_csv_import_to_hyperfile, + handle_hyper_file_job_completion, + schedule_hyper_file_cron_job, +) + + +class UnsupportedForm(Exception): + pass + + +class ConnectionRequestError(Exception): + pass + + +class CSVExportFailure(Exception): + pass + + +class DoesNotExist(Exception): + pass + + +def get_access_token(user: User, server: Server, db: SessionLocal) -> Optional[str]: + url = f"{server.url}{ONADATA_TOKEN_ENDPOINT}" + data = { + "grant_type": "refresh_token", + "refresh_token": user.decrypt_value(user.refresh_token), + "client_id": server.client_id, + } + resp = httpx.post( + url, + data=data, + auth=(server.client_id, server.decrypt_value(server.client_secret)), + ) + if resp.status_code == 200: + resp = resp.json() + user = User.get(db, user.id) + user.refresh_token = user.encrypt_value(resp.get("refresh_token")) + db.commit() + return resp.get("access_token") + return None + + +def _get_csv_export( + url: str, headers: dict = None, temp_token: str = None, retries: int = 0 +): + def _write_export_to_temp_file(export_url, headers, retry: int = 0): + print("Writing to temporary CSV Export to temporary file.") + retry = 0 or retry + status = 0 + with NamedTemporaryFile(delete=False, suffix=".csv") as export: + with httpx.stream("GET", export_url, headers=headers) as response: + if response.status_code == 200: + for chunk in response.iter_bytes(): + export.write(chunk) + return export + status = response.status_code + if retry < 3: + print( + f"Retrying export write: Status {status}, Retry {retry}, URL {export_url}" + ) + _write_export_to_temp_file( + export_url=export_url, headers=headers, retry=retry + 1 + ) + + print("Checking on export status.") + resp = httpx.get(url, headers=headers) + + if resp.status_code 
== 202: + resp = resp.json() + job_status = resp.get("job_status") + if "export_url" in resp and job_status == "SUCCESS": + export_url = resp.get("export_url") + if temp_token: + export_url += f"&temp_token={temp_token}" + return _write_export_to_temp_file(export_url, headers) + elif job_status == "FAILURE": + reason = resp.get("progress") + raise CSVExportFailure(f"CSV Export Failure. Reason: {reason}") + + job_uuid = resp.get("job_uuid") + if job_uuid: + print(f"Waiting for CSV Export to be ready. Job UUID: {job_uuid}") + url += f"&job_uuid={job_uuid}" + + if retries < 3: + time.sleep(30 * (retries + 1)) + return _get_csv_export( + url, headers=headers, temp_token=temp_token, retries=retries + 1 + ) + else: + raise ConnectionRequestError( + f"Failed to retrieve CSV Export. URL: {url}, took too long for CSV Export to be ready" + ) + else: + raise ConnectionRequestError( + f"Failed to retrieve CSV Export. URL: {url}, Status Code: {resp.status_code}" + ) + + +def get_csv_export( + hyperfile: HyperFile, user: schemas.User, server: schemas.Server, db: SessionLocal +) -> str: + """ + Retrieves a CSV Export for an XForm linked to a Hyperfile + """ + bearer_token = get_access_token(user, server, db) + headers = { + "user-agent": f"{settings.app_name}/{settings.app_version}", + "Authorization": f"Bearer {bearer_token}", + } + form_url = f"{server.url}{ONADATA_FORMS_ENDPOINT}/{hyperfile.form_id}" + resp = httpx.get(form_url + ".json", headers=headers) + if resp.status_code == 200: + form_data = resp.json() + public = form_data.get("public") + url = f"{form_url}/export_async.json?format=csv" + temp_token = None + + # Retrieve auth credentials if XForm is private + # Onadatas' Export Endpoint only support TempToken or Basic Authentication + if not public: + resp = httpx.get( + f"{server.url}{ONADATA_USER_ENDPOINT}.json", headers=headers + ) + temp_token = resp.json().get("temp_token") + csv_export = _get_csv_export(url, headers, temp_token) + if csv_export: + return Path(csv_export.name) + + +def start_csv_import_to_hyper( + hyperfile_id: int, process: HyperProcess, schedule_cron: bool = True +): + db = SessionLocal() + hyperfile: HyperFile = HyperFile.get(db, object_id=hyperfile_id) + job_status: str = schemas.FileStatusEnum.file_available.value + err: Exception = None + + if hyperfile: + user = User.get(db, hyperfile.user) + server = Server.get(db, user.server) + + hyperfile.file_status = schemas.FileStatusEnum.syncing.value + db.commit() + db.refresh(hyperfile) + + try: + export = get_csv_export(hyperfile, user, server, db) + + if export: + handle_csv_import_to_hyperfile(hyperfile, export, process, db) + + if schedule_cron and not hyperfile.meta_data.get(JOB_ID_METADATA): + schedule_hyper_file_cron_job( + start_csv_import_to_hyper_job, hyperfile_id + ) + else: + job_status = schemas.FileStatusEnum.file_unavailable.value + except (CSVExportFailure, ConnectionRequestError, Exception) as exc: + err = exc + job_status = schemas.FileStatusEnum.latest_sync_failed.value + + successful_import = job_status == schemas.FileStatusEnum.file_available.value + handle_hyper_file_job_completion( + hyperfile.id, + db, + job_succeeded=successful_import, + object_updated=successful_import, + file_status=job_status, + ) + db.close() + if err: + sentry_sdk.capture_exception(err) + return successful_import + + +def start_csv_import_to_hyper_job(hyperfile_id: int, schedule_cron: bool = False): + if not caches.get(HYPER_PROCESS_CACHE_KEY): + caches.set( + HYPER_PROCESS_CACHE_KEY, + 
HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU), + ) + process: HyperProcess = caches.get(HYPER_PROCESS_CACHE_KEY) + start_csv_import_to_hyper(hyperfile_id, process, schedule_cron=schedule_cron) + + +def create_or_get_hyperfile( + db: Session, file_data: schemas.FileCreate, process: HyperProcess +): + hyperfile = HyperFile.get_using_file_create(db, file_data) + if hyperfile: + return hyperfile, False + + headers = {"user-agent": f"{settings.app_name}/{settings.app_version}"} + user = User.get(db, file_data.user) + server = Server.get(db, user.server) + bearer_token = get_access_token(user, server, db) + headers.update({"Authorization": f"Bearer {bearer_token}"}) + + url = f"{server.url}{ONADATA_FORMS_ENDPOINT}/{file_data.form_id}.json" + resp = httpx.get(url, headers=headers) + + if resp.status_code == 200: + resp = resp.json() + if "public_key" in resp and resp.get("public_key"): + raise UnsupportedForm("Encrypted forms are not supported") + + title = resp.get("title") + file_data.filename = f"{title}.hyper" + return HyperFile.create(db, file_data), True + else: + raise ConnectionRequestError( + f"Currently unable to start connection to form. Aggregate status code: {resp.status_code}" + ) + + +def schedule_all_active_forms(db: Session = SessionLocal(), close_db: bool = False): + """ + Schedule CSV Import Jobs for all active Hyper Files + """ + for hf in HyperFile.get_active_files(db): + schedule_hyper_file_cron_job(start_csv_import_to_hyper_job, hf.id) + + if close_db: + db.close() diff --git a/app/utils/utils.py b/app/utils/utils.py new file mode 100644 index 0000000..728d853 --- /dev/null +++ b/app/utils/utils.py @@ -0,0 +1,23 @@ +# Common/General Utilities +import redis + +from app.database import SessionLocal +from app.settings import settings + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + +def get_redis_client(): + redis_client = redis.Redis( + host=settings.redis_host, port=settings.redis_port, db=settings.redis_db + ) + try: + yield redis_client + finally: + pass diff --git a/dev-requirements.in b/dev-requirements.in new file mode 100644 index 0000000..2d94fc4 --- /dev/null +++ b/dev-requirements.in @@ -0,0 +1,9 @@ +ipdb +pip-tools +black +flake8 +pre-commit +pytest +pytest-cov +tox +fakeredis diff --git a/dev-requirements.pip b/dev-requirements.pip new file mode 100644 index 0000000..bac496f --- /dev/null +++ b/dev-requirements.pip @@ -0,0 +1,134 @@ +# +# This file is autogenerated by pip-compile +# To update, run: +# +# pip-compile --output-file=dev-requirements.pip dev-requirements.in +# +appdirs==1.4.4 + # via + # black + # virtualenv +attrs==20.3.0 + # via pytest +backcall==0.2.0 + # via ipython +black==20.8b1 + # via -r dev-requirements.in +cfgv==3.2.0 + # via pre-commit +click==7.1.2 + # via + # black + # pip-tools +coverage==5.4 + # via pytest-cov +decorator==4.4.2 + # via ipython +distlib==0.3.1 + # via virtualenv +fakeredis==1.4.5 + # via -r dev-requirements.in +filelock==3.0.12 + # via + # tox + # virtualenv +flake8==3.8.4 + # via -r dev-requirements.in +identify==1.5.9 + # via pre-commit +iniconfig==1.1.1 + # via pytest +ipdb==0.13.4 + # via -r dev-requirements.in +ipython-genutils==0.2.0 + # via traitlets +ipython==7.18.1 + # via ipdb +jedi==0.17.2 + # via ipython +mccabe==0.6.1 + # via flake8 +mypy-extensions==0.4.3 + # via black +nodeenv==1.5.0 + # via pre-commit +packaging==20.7 + # via + # pytest + # tox +parso==0.7.1 + # via jedi +pathspec==0.8.1 + # via black +pexpect==4.8.0 + # via ipython 
+pickleshare==0.7.5 + # via ipython +pip-tools==5.3.1 + # via -r dev-requirements.in +pluggy==0.13.1 + # via + # pytest + # tox +pre-commit==2.8.2 + # via -r dev-requirements.in +prompt-toolkit==3.0.8 + # via ipython +ptyprocess==0.6.0 + # via pexpect +py==1.9.0 + # via + # pytest + # tox +pycodestyle==2.6.0 + # via flake8 +pyflakes==2.2.0 + # via flake8 +pygments==2.7.2 + # via ipython +pyparsing==2.4.7 + # via packaging +pytest-cov==2.11.1 + # via -r dev-requirements.in +pytest==6.1.2 + # via + # -r dev-requirements.in + # pytest-cov +pyyaml==5.3.1 + # via pre-commit +redis==3.5.3 + # via fakeredis +regex==2020.10.28 + # via black +six==1.15.0 + # via + # fakeredis + # pip-tools + # tox + # virtualenv +sortedcontainers==2.3.0 + # via fakeredis +toml==0.10.2 + # via + # black + # pre-commit + # pytest + # tox +tox==3.20.1 + # via -r dev-requirements.in +traitlets==5.0.5 + # via ipython +typed-ast==1.4.1 + # via black +typing-extensions==3.7.4.3 + # via black +virtualenv==20.1.0 + # via + # pre-commit + # tox +wcwidth==0.2.5 + # via prompt-toolkit + +# The following packages are considered to be unsafe in a requirements file: +# pip +# setuptools diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..1c3df0b --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,75 @@ +version: "3" + +services: + cache: + image: redis:6-alpine + ports: + - 6379:6379 + db: + image: postgres:13 + volumes: + - ../.duva_db:/var/lib/postgresql/data + environment: + - POSTGRES_PASSWORD=duva + - POSTGRES_USER=duva + - POSTGRES_DB=duva + app: + build: + context: . + dockerfile: Dockerfile + image: hypermind:latest + stdin_open: true + tty: true + volumes: + # For local development + - .:/app + - ~/.aws:/root/.aws + ports: + - "8000:80" + depends_on: + - cache + - db + environment: + - MEDIA_PATH=/app/media + - REDIS_URL=redis://cache:6379/1 + - QUEUE_NAME=default + - CRON_SCHEDULE=*/30 * * * * + - DEBUG=True + - DATABASE_URL=postgresql://duva:duva@db/duva + - RUN_MIGRATIONS=True + scheduler: + build: + context: . + dockerfile: Dockerfile + image: hypermind:latest + command: "bash init_scheduler.sh" + volumes: + # For local development + - .:/app + - ~/.aws:/root/.aws + depends_on: + - cache + - db + environment: + - REDIS_URL=redis://cache:6379/1 + - QUEUE_NAME=default + - CRON_SCHEDULE=*/30 * * * * + - SCHEDULE_ALL=False + - DATABASE_URL=postgresql://duva:duva@db/duva + worker: + build: + context: . 
+ dockerfile: Dockerfile + image: hypermind:latest + command: "rq worker -c app.jobs.settings" + volumes: + # For local development + - .:/app + - ~/.aws:/root/.aws + depends_on: + - cache + - db + environment: + - REDIS_URL=redis://cache:6379/1 + - QUEUE_NAME=default + - DATABASE_URL=postgresql://duva:duva@db/duva diff --git a/docs/flow-diagrams/managed-hyper-database-flow.png b/docs/flow-diagrams/managed-hyper-database-flow.png new file mode 100644 index 0000000..ecc4f63 Binary files /dev/null and b/docs/flow-diagrams/managed-hyper-database-flow.png differ diff --git a/docs/flow-diagrams/one-off-hyper-database-flow.png b/docs/flow-diagrams/one-off-hyper-database-flow.png new file mode 100644 index 0000000..1d0f652 Binary files /dev/null and b/docs/flow-diagrams/one-off-hyper-database-flow.png differ diff --git a/docs/flow-diagrams/server-registration-flow.png b/docs/flow-diagrams/server-registration-flow.png new file mode 100644 index 0000000..f4c2ad1 Binary files /dev/null and b/docs/flow-diagrams/server-registration-flow.png differ diff --git a/init_scheduler.sh b/init_scheduler.sh new file mode 100755 index 0000000..60fae09 --- /dev/null +++ b/init_scheduler.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +python3 app/jobs/scheduler.py && rqscheduler --host cache --port 6379 --db 1 \ No newline at end of file diff --git a/prestart.sh b/prestart.sh new file mode 100755 index 0000000..08c94fb --- /dev/null +++ b/prestart.sh @@ -0,0 +1,6 @@ +#! /usr/bin/env bash + +# Run migrations +if [ "$RUN_MIGRATIONS" = "True" ]; then + alembic upgrade head +fi diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000..e793dd6 --- /dev/null +++ b/requirements.in @@ -0,0 +1,18 @@ +fastapi[all] # Should probably switch this to fastapi +alembic +httpx +tableauhyperapi +fastapi-cache +pyxform +python-dateutil +pandas +boto3 +tableauserverclient +sqlalchemy +cryptography +rq +rq-scheduler +psycopg2-binary +redis +pyjwt +sentry-sdk diff --git a/requirements.pip b/requirements.pip new file mode 100644 index 0000000..095c2fb --- /dev/null +++ b/requirements.pip @@ -0,0 +1,202 @@ +# +# This file is autogenerated by pip-compile +# To update, run: +# +# pip-compile --output-file=requirements.pip requirements.in +# +aiofiles==0.5.0 + # via fastapi +aioredis==1.3.1 + # via fastapi-cache +alembic==1.4.3 + # via -r requirements.in +aniso8601==7.0.0 + # via graphene +argparse==1.4.0 + # via unittest2 +async-exit-stack==1.0.1 + # via fastapi +async-generator==1.10 + # via fastapi +async-timeout==3.0.1 + # via aioredis +boto3==1.16.12 + # via -r requirements.in +botocore==1.19.12 + # via + # boto3 + # s3transfer +certifi==2020.6.20 + # via + # httpx + # requests + # sentry-sdk +cffi==1.14.2 + # via + # cryptography + # tableauhyperapi +chardet==3.0.4 + # via requests +click==7.1.2 + # via + # rq + # uvicorn +croniter==0.3.36 + # via rq-scheduler +cryptography==3.2.1 + # via -r requirements.in +dnspython==2.0.0 + # via email-validator +email-validator==1.1.1 + # via fastapi +fastapi-cache==0.0.5 + # via -r requirements.in +fastapi[all]==0.61.1 + # via -r requirements.in +formencode==2.0.0 + # via pyxform +graphene==2.1.8 + # via fastapi +graphql-core==2.3.2 + # via + # graphene + # graphql-relay +graphql-relay==2.0.1 + # via graphene +h11==0.9.0 + # via + # httpcore + # uvicorn +hiredis==1.1.0 + # via aioredis +httpcore==0.12.0 + # via httpx +httptools==0.1.1 + # via uvicorn +httpx==0.16.1 + # via -r requirements.in +idna==2.10 + # via + # email-validator + # requests + # rfc3986 +itsdangerous==1.1.0 
+ # via fastapi +jinja2==2.11.2 + # via fastapi +jmespath==0.10.0 + # via + # boto3 + # botocore +linecache2==1.0.0 + # via traceback2 +mako==1.1.3 + # via alembic +markupsafe==1.1.1 + # via + # jinja2 + # mako +natsort==7.0.1 + # via croniter +numpy==1.19.4 + # via pandas +orjson==3.4.2 + # via fastapi +pandas==1.1.4 + # via -r requirements.in +promise==2.3 + # via + # graphql-core + # graphql-relay +psycopg2-binary==2.8.6 + # via -r requirements.in +pycparser==2.20 + # via cffi +pydantic==1.7.1 + # via fastapi +pyjwt==1.7.1 + # via -r requirements.in +python-dateutil==2.8.1 + # via + # -r requirements.in + # alembic + # botocore + # croniter + # pandas +python-editor==1.0.4 + # via alembic +python-multipart==0.0.5 + # via fastapi +pytz==2020.4 + # via pandas +pyxform==1.2.0 + # via -r requirements.in +pyyaml==5.3.1 + # via fastapi +redis==3.5.3 + # via + # -r requirements.in + # rq +requests==2.24.0 + # via + # fastapi + # tableauserverclient +rfc3986[idna2008]==1.4.0 + # via httpx +rq-scheduler==0.10.0 + # via -r requirements.in +rq==1.6.1 + # via + # -r requirements.in + # rq-scheduler +rx==1.6.1 + # via graphql-core +s3transfer==0.3.3 + # via boto3 +sentry-sdk==0.19.5 + # via -r requirements.in +six==1.15.0 + # via + # cryptography + # formencode + # graphene + # graphql-core + # graphql-relay + # python-dateutil + # python-multipart + # unittest2 +sniffio==1.2.0 + # via + # httpcore + # httpx +sqlalchemy==1.3.20 + # via + # -r requirements.in + # alembic +starlette==0.13.6 + # via fastapi +tableauhyperapi==0.0.11556 + # via -r requirements.in +tableauserverclient==0.13 + # via -r requirements.in +traceback2==1.4.0 + # via unittest2 +ujson==3.2.0 + # via fastapi +unicodecsv==0.14.1 + # via pyxform +unittest2==1.1.0 + # via pyxform +urllib3==1.25.11 + # via + # botocore + # requests + # sentry-sdk +uvicorn==0.11.8 + # via fastapi +uvloop==0.14.0 + # via uvicorn +websockets==8.1 + # via uvicorn +xlrd==1.2.0 + # via pyxform diff --git a/scripts/make-migrations.sh b/scripts/make-migrations.sh new file mode 100755 index 0000000..d6fced4 --- /dev/null +++ b/scripts/make-migrations.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +PYTHONPATH=. alembic revision --autogenerate -m "$@" diff --git a/scripts/migrate.sh b/scripts/migrate.sh new file mode 100755 index 0000000..64b69df --- /dev/null +++ b/scripts/migrate.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +PYTHONPATH=. alembic upgrade head diff --git a/scripts/run-tests.sh b/scripts/run-tests.sh new file mode 100755 index 0000000..3de3035 --- /dev/null +++ b/scripts/run-tests.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +PYTHONPATH=. pytest --cov-config=.coveragerc --cov=app -s app/tests diff --git a/scripts/start.sh b/scripts/start.sh new file mode 100755 index 0000000..fceb854 --- /dev/null +++ b/scripts/start.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +PYTHONPATH=. python app/main.py diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..6f5f841 --- /dev/null +++ b/setup.py @@ -0,0 +1,12 @@ +from distutils.core import setup + + +setup( + name="Duva", + version="0.0.1", + description="", + author="Ona Kenya", + license="Apache 2.0", + author_email="tech@ona.io", + url="https://github.com/onaio/hypermind", +) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..c74b077 --- /dev/null +++ b/tox.ini @@ -0,0 +1,15 @@ +[tox] +envlist = py36,py37,py38,lint + +[testenv] +deps = + -rrequirements.pip + -rdev-requirements.pip +commands = + sh -c './scripts/run-tests.sh' + +[testenv:lint] +deps = -rdev-requirements.pip +commands = + black --check app + flake8 app
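
For reference, the `worker` and `scheduler` services above drive `start_csv_import_to_hyper_job` through RQ. Below is a minimal sketch of enqueueing that job by hand, assuming Redis is reachable at the `REDIS_URL` used in `docker-compose.yml` and that the job function is importable from a module such as `app.jobs.jobs` (the exact module path and a real `hyperfile_id` are not shown in this diff, so both are placeholders):

```python
# Hedged sketch: manually enqueue the CSV-import job on the "default" queue.
# Assumptions: Redis is reachable at the compose REDIS_URL (adjust the host if
# running outside docker-compose), and start_csv_import_to_hyper_job lives in
# app.jobs.jobs (hypothetical import path).
from redis import Redis
from rq import Queue

from app.jobs.jobs import start_csv_import_to_hyper_job  # hypothetical path

queue = Queue("default", connection=Redis.from_url("redis://cache:6379/1"))

# hyperfile_id=1 is a placeholder; schedule_cron=True also lets the job
# register its recurring rq-scheduler entry, as start_csv_import_to_hyper
# does when no JOB_ID_METADATA is present.
queue.enqueue(start_csv_import_to_hyper_job, hyperfile_id=1, schedule_cron=True)
```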