Skip to content

Commit

Permalink
feat: Add on error rollback
Browse files Browse the repository at this point in the history
  • Loading branch information
gma2th committed Aug 5, 2018
1 parent ac45a6d commit e2e3917
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 3 deletions.
3 changes: 3 additions & 0 deletions pgcli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
self.multi_line = c['main'].as_bool('multi_line')
self.multiline_mode = c['main'].get('multi_line_mode', 'psql')
self.autocommit_mode = c['main'].as_bool('autocommit_mode')
self.on_error_rollback = c['main'].as_bool('on_error_rollback')
self.vi_mode = c['main'].as_bool('vi')
self.auto_expand = auto_vertical_output or c['main'].as_bool(
'auto_expand')
Expand Down Expand Up @@ -462,6 +463,7 @@ def connect(self, database='', host='', user='', port='', passwd='',
try:
pgexecute = PGExecute(database, user, passwd, host, port, dsn,
autocommit_mode=self.autocommit_mode,
on_error_rollback=self.on_error_rollback,
application_name='pgcli', **kwargs)
except (OperationalError, InterfaceError) as e:
if ('no password supplied' in utf8tounicode(e.args[0]) and
Expand All @@ -471,6 +473,7 @@ def connect(self, database='', host='', user='', port='', passwd='',
type=str)
pgexecute = PGExecute(database, user, passwd, host, port, dsn,
autocommit_mode=self.autocommit_mode,
on_error_rollback=self.on_error_rollback,
application_name='pgcli',
**kwargs)
else:
Expand Down
6 changes: 5 additions & 1 deletion pgcli/pgclirc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ multi_line_mode = psql
# Similar to `\set AUTOCOMMIT on/off` in psql.
autocommit_mode = True

# When set to on, if a statement in a transaction block generates an error, the error is ignored
# and the transaction continues. Similar `\set ON_ERROR_ROLLBACK` on in psql.
on_error_rollback = False

# Destructive warning mode will alert you before executing a sql statement
# that may cause harm to the database such as "drop table", "drop database"
# or "shutdown".
Expand Down Expand Up @@ -144,7 +148,7 @@ null_string = '<null>'
# manage pager on startup
enable_pager = True

# Use keyring to automatically save and load password in a secure manner
# Use keyring to automatically save and load password in a secure manner
keyring = True

# Custom colors for the completion menu, toolbar, etc.
Expand Down
60 changes: 58 additions & 2 deletions pgcli/pgexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,32 @@
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1

TRANSACTION_CONTROL_COMMANDS = (
'abort',
'begin',
'commit',
'end',
'prepare transaction',
'rollback',
'start transaction',
)

COMMANDS_NOT_ALLOWED_IN_TRANSACTION = (
'alter system',
'cluster',
'create database',
'create index concurrently',
'create tablespace',
'create unique index concurrently',
'discard all',
'drop database',
'drop index concurrently',
'drop tablespace',
'reindex database',
'reindex system',
'vacuum',
)


def _wait_select(conn):
"""
Expand Down Expand Up @@ -143,6 +169,30 @@ def init_connection_from_parameters(dbname, user, password, host, port, **kwargs
return conn


def execute_sql_with_on_error_rollback(conn, user_cur, statement_):
statement = ' '.join(statement_.lower().strip().split())

if statement.startswith(COMMANDS_NOT_ALLOWED_IN_TRANSACTION + TRANSACTION_CONTROL_COMMANDS):
user_cur.execute(statement)
return

if (conn.get_transaction_status() == psycopg2.extensions.TRANSACTION_STATUS_IDLE and
not conn.autocommit):
with conn.cursor() as pgcli_cur:
pgcli_cur.execute('BEGIN')
with conn.cursor() as pgcli_cur:
pgcli_cur.execute('SAVEPOINT pgcli_user_conn_tmp_savepoint')
try:
user_cur.execute(statement)
except psycopg2.DatabaseError:
with conn.cursor() as pgcli_cur:
pgcli_cur.execute('ROLLBACK TO pgcli_user_conn_tmp_savepoint')
raise
else:
with conn.cursor() as pgcli_cur:
pgcli_cur.execute('RELEASE pgcli_user_conn_tmp_savepoint')


class PGExecute(object):

# The boolean argument to the current_schemas function indicates whether
Expand Down Expand Up @@ -208,7 +258,8 @@ class PGExecute(object):

version_query = "SELECT version();"

def __init__(self, database, user, password, host, port, dsn, autocommit_mode=True, **kwargs):
def __init__(self, database, user, password, host, port, dsn,
autocommit_mode=True, on_error_rollback=False, **kwargs):
self.dbname = database
self.user = user
self.password = password
Expand All @@ -217,6 +268,8 @@ def __init__(self, database, user, password, host, port, dsn, autocommit_mode=Tr
self.dsn = dsn
self.extra_args = {k: unicode2utf8(v) for k, v in kwargs.items()}
self.server_version = None
self.autocommit_mode = autocommit_mode
self.on_error_rollback = on_error_rollback
self.connect()
# user_conn is the connection used to execute user queries
self.user_conn = self.init_connection(self.dbname, self.user, self.password, self.host,
Expand Down Expand Up @@ -403,7 +456,10 @@ def execute_normal_sql(self, split_sql):
"""Returns tuple (title, rows, headers, status)"""
_logger.debug('Regular sql statement. sql: %r', split_sql)
cur = self.user_conn.cursor()
cur.execute(split_sql)
if not self.autocommit_mode and self.on_error_rollback:
execute_sql_with_on_error_rollback(self.user_conn, cur, split_sql)
else:
cur.execute(split_sql)

# conn.notices persist between queies, we use pop to clear out the list
title = ''
Expand Down

0 comments on commit e2e3917

Please sign in to comment.