diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 4ccdb3612..210fd3f10 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -99,7 +99,8 @@ def register_json_typecasters(conn, loads_fn): psycopg2.extras.register_json(conn, loads=loads_fn, name=name) available.add(name) except psycopg2.ProgrammingError: - pass + if not conn.autocommit: + conn.rollback() return available @@ -116,10 +117,32 @@ def register_hstore_typecaster(conn): cur.execute("SELECT 'hstore'::regtype::oid") oid = cur.fetchone()[0] ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE)) + except psycopg2.ProgrammingError: + if not conn.autocommit: + conn.rollback() except Exception: pass +def init_connection_from_dsn(dsn, password): + if password: + dsn = "{0} password={1}".format(dsn, password) + conn = psycopg2.connect(dsn=unicode2utf8(dsn)) + return conn + + +def init_connection_from_parameters(dbname, user, password, host, port, **kwargs): + conn = psycopg2.connect( + database=unicode2utf8(dbname), + user=unicode2utf8(user), + password=unicode2utf8(password), + host=unicode2utf8(host), + port=unicode2utf8(port), + **kwargs + ) + return conn + + class PGExecute(object): # The boolean argument to the current_schemas function indicates whether @@ -195,7 +218,9 @@ def __init__(self, database, user, password, host, port, dsn, autocommit_mode=Tr self.extra_args = {k: unicode2utf8(v) for k, v in kwargs.items()} self.server_version = None self.connect() - self.user_conn = self.get_new_connection(autocommit=autocommit_mode) + # user_conn is the connection used to execute user queries + self.user_conn = self.init_connection(self.dbname, self.user, self.password, self.host, + self.port, self.dsn, autocommit_mode, **kwargs) def get_server_version(self): if self.server_version: @@ -214,9 +239,24 @@ def get_server_version(self): self.server_version = '' return self.server_version + def init_connection(self, database=None, user=None, password=None, host=None, + port=None, dsn=None, autocommit_mode=True, **kwargs): + if dsn: + conn = init_connection_from_dsn(dsn, password) + else: + conn = init_connection_from_parameters( + database, user, password, host, port, **kwargs) + conn.set_client_encoding('utf8') + # Need to be set before any executed query + conn.set_session(autocommit=autocommit_mode) + register_date_typecasters(conn) + register_json_typecasters(conn, self._json_typecaster) + register_hstore_typecaster(conn) + return conn + def connect(self, database=None, user=None, password=None, host=None, port=None, dsn=None, **kwargs): - + """Setup pgcli internal connection.""" db = (database or self.dbname) user = (user or self.user) password = (password or self.password) @@ -224,73 +264,34 @@ def connect(self, database=None, user=None, password=None, host=None, port = (port or self.port) dsn = (dsn or self.dsn) kwargs = (kwargs or self.extra_args) - pid = -1 - if dsn: - if password: - dsn = "{0} password={1}".format(dsn, password) - conn = psycopg2.connect(dsn=unicode2utf8(dsn)) - cursor = conn.cursor() - else: - conn = psycopg2.connect( - database=unicode2utf8(db), - user=unicode2utf8(user), - password=unicode2utf8(password), - host=unicode2utf8(host), - port=unicode2utf8(port), - **kwargs) - cursor = conn.cursor() - - conn.set_client_encoding('utf8') + conn = self.init_connection( + db, user, password, host, port, dsn, **kwargs) if hasattr(self, 'conn'): self.conn.close() self.conn = conn - self.conn.autocommit = True - if dsn: - # When we connect using a DSN, we don't really know what db, - # user, etc. we connected to. Let's read it. - # Note: moved this after setting autocommit because of #664. - dsn_parameters = conn.get_dsn_parameters() - db = dsn_parameters['dbname'] - user = dsn_parameters['user'] - host = dsn_parameters['host'] - port = dsn_parameters['port'] - - self.dbname = db - self.user = user + # Ensure class attribute are set + # When we connect using a DSN, we don't really know what db, + # user, etc. we connected to. Let's read it. + # Note: moved this after setting autocommit because of #664. + dsn_parameters = conn.get_dsn_parameters() + self.dbname = dsn_parameters['dbname'] + self.user = dsn_parameters['user'] + self.host = dsn_parameters['host'] + self.port = dsn_parameters['port'] self.password = password - self.host = host - self.port = port - if not self.host: self.host = self.get_socket_directory() + cursor = conn.cursor() cursor.execute("SHOW ALL") - db_parameters = dict(name_val_desc[:2] for name_val_desc in cursor.fetchall()) - pid = self._select_one(cursor, 'select pg_backend_pid()')[0] self.pid = pid - self.superuser = db_parameters.get('is_superuser') == '1' - register_date_typecasters(conn) - register_json_typecasters(self.conn, self._json_typecaster) - register_hstore_typecaster(self.conn) - - def get_new_connection(self, autocommit=True): - conn = psycopg2.connect( - database=unicode2utf8(self.dbname), - user=unicode2utf8(self.user), - password=unicode2utf8(self.password), - host=unicode2utf8(self.host), - port=unicode2utf8(self.port) - ) - conn.set_client_encoding('utf8') - conn.set_session(autocommit=autocommit) - register_date_typecasters(conn) - register_json_typecasters(conn, self._json_typecaster) - register_hstore_typecaster(conn) - return conn + db_parameters = dict(name_val_desc[:2] + for name_val_desc in cursor.fetchall()) + self.superuser = db_parameters.get('is_superuser') == '1' def _select_one(self, cur, sql):