Skip to content

Commit

Permalink
chore(yml): add database types to connectors prompts (#1039)
Browse files Browse the repository at this point in the history
* chore(yml): add types of yml connectors

* Update pandasai/connectors/base.py

Co-authored-by: Massimiliano Pronesti <[email protected]>

* update requested changes

---------

Co-authored-by: Massimiliano Pronesti <[email protected]>
  • Loading branch information
ArslanSaleem and mspronesti committed Mar 15, 2024
1 parent bcf80a5 commit d3b2fd9
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 4 deletions.
4 changes: 4 additions & 0 deletions pandasai/connectors/airtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,7 @@ def column_hash(self):
columns_str = "|".join(self._instance.columns)
columns_str += f"WHERE{self._build_formula()}"
return hashlib.sha256(columns_str.encode("utf-8")).hexdigest()

@property
def type(self):
return "SQL"
9 changes: 8 additions & 1 deletion pandasai/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ def pandas_df(self):
"""
raise NotImplementedError

@property
def type(self):
"""
type of the connector
"""
pass

def equals(self, other):
return self.__dict__ == other.__dict__

Expand Down Expand Up @@ -270,7 +277,7 @@ def to_string(
self,
extras={
"index": index,
"type": "sql" if is_direct_sql else "pd.DataFrame",
"type": "pd.DataFrame",
"is_direct_sql": is_direct_sql,
},
type_=serializer,
Expand Down
4 changes: 4 additions & 0 deletions pandasai/connectors/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def fallback_name(self):
"""
pass

@property
def type(self):
return "pd.DataFrame"

def equals(self, other: BaseConnector):
"""
Return whether the data source that the connector is connected to is
Expand Down
4 changes: 4 additions & 0 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,10 @@ def execute_direct_sql_query(self, sql_query):
def cs_table_name(self):
return self.config.table

@property
def type(self):
return self.config.dialect


class SqliteConnector(SQLConnector):
"""
Expand Down
19 changes: 19 additions & 0 deletions pandasai/ee/connectors/google_big_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,22 @@ def __repr__(self):
f"<{self.__class__.__name__} dialect={self.config.dialect} "
f"projectid= {self.config.projectID} database={self.config.database} >"
)

def equals(self, other):
if isinstance(other, self.__class__):
return (
self.config.dialect,
self.config.driver,
self.config.credentials_path,
self.config.credentials_base64,
self.config.database,
self.config.projectID,
) == (
other.config.dialect,
other.config.driver,
other.config.credentials_path,
other.config.credentials_base64,
other.config.database,
other.config.projectID,
)
return False
6 changes: 5 additions & 1 deletion pandasai/helpers/dataframe_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict:
df_info = {
"name": df.name,
"description": df.description,
"type": extras["type"],
"type": (
df.type
if "is_direct_sql" in extras and extras["is_direct_sql"]
else extras["type"]
),
}
# Add DataFrame details to the result
data = {
Expand Down
48 changes: 48 additions & 0 deletions tests/unit_tests/connectors/test_google_big_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,51 @@ def test_fallback_name_property(self):
# Test fallback_name property
fallback_name = self.connector.fallback_name
self.assertEqual(fallback_name, "yourtable")

@patch("pandasai.ee.connectors.google_big_query.create_engine", autospec=True)
def test_constructor_and_properties_equal_func(self, mock_create_engine):
self.mock_engine = Mock()
self.mock_connection = Mock()
self.mock_engine.connect.return_value = self.mock_connection
mock_create_engine.return_value = self.mock_engine

self.config = GoogleBigQueryConnectorConfig(
dialect="bigquery",
database="database",
table="yourtable",
credentials_base64="base64_str",
projectID="project_id",
).dict()

self.connector = GoogleBigQueryConnector(self.config)
connector_2 = GoogleBigQueryConnector(self.config)

assert self.connector.equals(connector_2)

@patch("pandasai.ee.connectors.google_big_query.create_engine", autospec=True)
def test_constructor_and_properties_not_equal_func(self, mock_create_engine):
self.mock_engine = Mock()
self.mock_connection = Mock()
self.mock_engine.connect.return_value = self.mock_connection
mock_create_engine.return_value = self.mock_engine

self.config = GoogleBigQueryConnectorConfig(
dialect="bigquery",
database="database",
table="yourtable",
credentials_base64="base64_str",
projectID="project_id",
).dict()

config2 = GoogleBigQueryConnectorConfig(
dialect="bigquery",
database="database2",
table="yourtable",
credentials_base64="base64_str",
projectID="project_id",
).dict()

self.connector = GoogleBigQueryConnector(self.config)
connector_2 = GoogleBigQueryConnector(config2)

assert not self.connector.equals(connector_2)
42 changes: 40 additions & 2 deletions tests/unit_tests/connectors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd

from pandasai.connectors.sql import (
MySQLConnector,
PostgreSQLConnector,
SQLConnector,
SQLConnectorConfig,
Expand Down Expand Up @@ -184,8 +185,8 @@ def test_equals_different_configs(
def test_equals_different_connector(self, mock_init_connection):
# Define your ConnectorConfig instance here
self.config = SQLConnectorConfig(
dialect="mysql",
driver="pymysql",
dialect="postgresql",
driver="psycopg2",
username="your_username_differ",
password="your_password",
host="your_host",
Expand All @@ -199,3 +200,40 @@ def test_equals_different_connector(self, mock_init_connection):
connector_2 = PostgreSQLConnector(self.config)

assert not self.connector.equals(connector_2)

@patch("pandasai.connectors.SQLConnector._init_connection")
def test_equals_connector_type(self, mock_init_connection):
# Define your ConnectorConfig instance here
config = {
"username": "your_username_differ",
"password": "your_password",
"host": "your_host",
"port": 443,
"database": "your_database",
"table": "your_table",
"where": [["column_name", "=", "value"]],
}

# Create an instance of SQLConnector
connector_2 = PostgreSQLConnector(config)

assert connector_2.type == "postgresql"

@patch("pandasai.connectors.SQLConnector._init_connection")
def test_equals_sql_connector_type(self, mock_init_connection):
# Define your ConnectorConfig instance here

config = {
"username": "your_username_differ",
"password": "your_password",
"host": "your_host",
"port": 443,
"database": "your_database",
"table": "your_table",
"where": [["column_name", "=", "value"]],
}

# Create an instance of SQLConnector
connector_2 = MySQLConnector(config)

assert connector_2.type == "mysql"

0 comments on commit d3b2fd9

Please sign in to comment.