diff options
author | Steve McIntyre <steve.mcintyre@linaro.org> | 2014-12-05 17:17:09 +0000 |
---|---|---|
committer | Steve McIntyre <steve.mcintyre@linaro.org> | 2014-12-05 17:17:09 +0000 |
commit | e13711078936884dcce49f815720726430a2aa4b (patch) | |
tree | ee27cbfa58eb10dfa61e91be803a1809b937e3d9 /db | |
parent | 4b91813381216e490b6ffa1a105bbbd7fb75833a (diff) |
More consistemcy updates to the database layer
Add "RETURNING foo_id" to all UPDATE calls, so we can return it to
callers. Also make sure that all UPDATEs get commits to match.
Change-Id: I2fa417dcb38a5612bbb9dac102718430d73343af
Diffstat (limited to 'db')
-rw-r--r-- | db/db.py | 14 |
1 files changed, 10 insertions, 4 deletions
@@ -334,7 +334,7 @@ class VlanDB: if port is None: raise InputError("Port ID %d does not exist" % int(port_id)) try: - sql = "UPDATE port SET is_locked=%s WHERE port_id=%s" + sql = "UPDATE port SET is_locked=%s WHERE port_id=%s RETURNING port_id" data = (is_locked, port_id) self.cursor.execute(sql, data) port_id = self.cursor.fetchone()[0] @@ -357,7 +357,7 @@ class VlanDB: else: raise InputError("Port mode %s is not valid" % mode) try: - sql = "UPDATE port SET is_trunk=%s WHERE port_id=%s" + sql = "UPDATE port SET is_trunk=%s WHERE port_id=%s RETURNING port_id" data = (is_trunk, port_id) self.cursor.execute(sql, data) port_id = self.cursor.fetchone()[0] @@ -380,12 +380,15 @@ class VlanDB: raise InputError("VLAN ID %d does not exist" % int(vlan_id)) try: - sql = "UPDATE port SET current_vlan_id=%s WHERE port_id=%s" + sql = "UPDATE port SET current_vlan_id=%s WHERE port_id=%s RETURNING port_id" data = (vlan_id, port_id) self.cursor.execute(sql, data) + port_id = self.cursor.fetchone()[0] + self.connection.commit() except: self.connection.rollback() raise + return port_id def restore_default_vlan(self, port_id): port = self.get_port_by_id(port_id) @@ -396,12 +399,15 @@ class VlanDB: raise CriticalError("The port is locked") try: - sql = "UPDATE port SET current_vlan_id=base_vlan_id WHERE port_id=%s" + sql = "UPDATE port SET current_vlan_id=base_vlan_id WHERE port_id=%s RETURNING port_id" data = (port_id,) self.cursor.execute(sql, data) + port_id = self.cursor.fetchone()[0] + self.connection.commit() except: self.connection.rollback() raise + return port_id def _dump_table(self, table): result = [] |