Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat) text wrapper for querys #72

Merged
merged 2 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions utils/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import streamlit as st

import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from utils.data_manipulation import StrategyData
Expand Down Expand Up @@ -64,13 +64,13 @@ def configs(self):
def get_config_files(self):
with self.session_maker() as session:
query = 'SELECT DISTINCT config_file_path FROM TradeFill'
config_files = pd.read_sql_query(query, session.connection())
config_files = pd.read_sql_query(text(query), session.connection())
return config_files['config_file_path'].tolist()

def get_exchanges_trading_pairs_by_config_file(self, config_file_path):
with self.session_maker() as session:
query = f"SELECT DISTINCT market, symbol FROM TradeFill WHERE config_file_path = '{config_file_path}'"
exchanges_trading_pairs = pd.read_sql_query(query, session.connection())
exchanges_trading_pairs = pd.read_sql_query(text(query), session.connection())
exchanges_trading_pairs["market"] = exchanges_trading_pairs["market"].apply(
lambda x: x.lower().replace("_papertrade", ""))
exchanges_trading_pairs = exchanges_trading_pairs.groupby("market")["symbol"].apply(list).to_dict()
Expand Down Expand Up @@ -134,7 +134,7 @@ def _get_market_data_query(start_date=None, end_date=None):
def get_orders(self, config_file_path=None, start_date=None, end_date=None):
with self.session_maker() as session:
query = self._get_orders_query(config_file_path, start_date, end_date)
orders = pd.read_sql_query(query, session.connection())
orders = pd.read_sql_query(text(query), session.connection())
orders["market"] = orders["market"].apply(lambda x: x.lower().replace("_papertrade", ""))
orders["amount"] = orders["amount"] / 1e6
orders["price"] = orders["price"] / 1e6
Expand All @@ -147,7 +147,7 @@ def get_trade_fills(self, config_file_path=None, start_date=None, end_date=None)
float_cols = ["amount", "price", "trade_fee_in_quote"]
with self.session_maker() as session:
query = self._get_trade_fills_query(config_file_path, start_date, end_date)
trade_fills = pd.read_sql_query(query, session.connection())
trade_fills = pd.read_sql_query(text(query), session.connection())
trade_fills[float_cols] = trade_fills[float_cols] / 1e6
trade_fills["cum_fees_in_quote"] = trade_fills.groupby(groupers)["trade_fee_in_quote"].cumsum()
trade_fills["net_amount"] = trade_fills['amount'] * trade_fills['trade_type'].apply(lambda x: 1 if x == 'BUY' else -1)
Expand All @@ -168,13 +168,13 @@ def get_trade_fills(self, config_file_path=None, start_date=None, end_date=None)
def get_order_status(self, order_ids=None, start_date=None, end_date=None):
with self.session_maker() as session:
query = self._get_order_status_query(order_ids, start_date, end_date)
order_status = pd.read_sql_query(query, session.connection())
order_status = pd.read_sql_query(text(query), session.connection())
return order_status

def get_market_data(self, start_date=None, end_date=None):
with self.session_maker() as session:
query = self._get_market_data_query(start_date, end_date)
market_data = pd.read_sql_query(query, session.connection())
market_data = pd.read_sql_query(text(query), session.connection())
market_data["timestamp"] = pd.to_datetime(market_data["timestamp"] / 1e6, unit="ms")
market_data.set_index("timestamp", inplace=True)
market_data["mid_price"] = market_data["mid_price"] / 1e6
Expand Down
34 changes: 17 additions & 17 deletions utils/optuna_database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json

import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from utils.data_manipulation import StrategyData
Expand All @@ -20,7 +20,7 @@ def status(self):
try:
with self.session_maker() as session:
query = 'SELECT * FROM trials WHERE state = "COMPLETE"'
completed_trials = pd.read_sql_query(query, session.connection())
completed_trials = pd.read_sql_query(text(query), session.connection())
if len(completed_trials) > 0:
# TODO: improve error handling, think what to do with other cases
return "OK"
Expand All @@ -37,7 +37,7 @@ def _get_tables(self):
try:
with self.session_maker() as session:
query = "SELECT name FROM sqlite_master WHERE type='table';"
tables = pd.read_sql_query(query, session.connection())
tables = pd.read_sql_query(text(query), session.connection())
return tables["name"].tolist()
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -49,7 +49,7 @@ def trials(self):
def _get_trials_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM trials", session.connection())
df = pd.read_sql_query(text("SELECT * FROM trials"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -61,7 +61,7 @@ def studies(self):
def _get_studies_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM studies", session.connection())
df = pd.read_sql_query(text("SELECT * FROM studies"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -73,7 +73,7 @@ def trial_params(self):
def _get_trial_params_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM trial_params", session.connection())
df = pd.read_sql_query(text("SELECT * FROM trial_params"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -85,7 +85,7 @@ def trial_values(self):
def _get_trial_values_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM trial_values", session.connection())
df = pd.read_sql_query(text("SELECT * FROM trial_values"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -97,7 +97,7 @@ def trial_system_attributes(self):
def _get_trial_system_attributes_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM trial_system_attributes", session.connection())
df = pd.read_sql_query(text("SELECT * FROM trial_system_attributes"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -109,7 +109,7 @@ def trial_system_attributes(self):
def _get_trial_system_attributes_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM trial_system_attributes", session.connection())
df = pd.read_sql_query(text("SELECT * FROM trial_system_attributes"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -121,7 +121,7 @@ def version_info(self):
def _get_version_info_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM version_info", session.connection())
df = pd.read_sql_query(text("SELECT * FROM version_info"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -133,7 +133,7 @@ def study_directions(self):
def _get_study_directions_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM study_directions", session.connection())
df = pd.read_sql_query(text("SELECT * FROM study_directions"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -145,7 +145,7 @@ def study_user_attributes(self):
def _get_study_user_attributes_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM study_user_attributes", session.connection())
df = pd.read_sql_query(text("SELECT * FROM study_user_attributes"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -157,7 +157,7 @@ def study_system_attributes(self):
def _get_study_system_attributes_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM study_system_attributes", session.connection())
df = pd.read_sql_query(text("SELECT * FROM study_system_attributes"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -169,7 +169,7 @@ def trial_user_attributes(self):
def _get_trial_user_attributes_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM trial_user_attributes", session.connection())
df = pd.read_sql_query(text("SELECT * FROM trial_user_attributes"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -181,7 +181,7 @@ def trial_intermediate_values(self):
def _get_trial_intermediate_values_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM trial_intermediate_values", session.connection())
df = pd.read_sql_query(text("SELECT * FROM trial_intermediate_values"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -193,7 +193,7 @@ def trial_heartbeats(self):
def _get_trial_heartbeats_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM trial_heartbeats", session.connection())
df = pd.read_sql_query(text("SELECT * FROM trial_heartbeats"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand All @@ -205,7 +205,7 @@ def alembic_version(self):
def _get_alembic_version_table(self):
try:
with self.session_maker() as session:
df = pd.read_sql_query("SELECT * FROM alembic_version", session.connection())
df = pd.read_sql_query(text("SELECT * FROM alembic_version"), session.connection())
return df
except Exception as e:
return f"Error: {str(e)}"
Expand Down