| #! /usr/bin/python |
| |
| # Copyright 2014-2018 Linaro Limited |
| # Authors: Dave Pigott <dave.pigott@linaro.org>, |
| # Steve McIntyre <steve.mcintyre@linaro.org> |
| # |
| # This program is free software; you can redistribute it and/or modify |
| # it under the terms of the GNU General Public License as published by |
| # the Free Software Foundation; either version 2 of the License, or |
| # (at your option) any later version. |
| # |
| # This program is distributed in the hope that it will be useful, |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| # GNU General Public License for more details. |
| # |
| # You should have received a copy of the GNU General Public License |
| # along with this program; if not, write to the Free Software |
| # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, |
| # MA 02110-1301, USA. |
| |
| import psycopg2 |
| import psycopg2.extras |
| import datetime, os, sys |
| import logging |
| |
# Value stored in a port's trunk_id field when the port is not attached
# to any trunk
TRUNK_ID_NONE = -1

# Schema version expected by this code. When an older version (or no
# version at all) is found at startup, the database is migrated
# automatically to this version.
#
# Version 0: Base; no version information stored
#
# Version 1: No schema changes other than recording the version and
#            coping with upgrades
#
# Version 2: Adds the "lock_reason" field to the port table, and the code
#            to deal with it
DATABASE_SCHEMA_VERSION = 2

# When run directly as a script, put the script's own directory and its
# parent at the front of the module search path (parent first) so that
# the project-local imports below can be resolved.
if __name__ == '__main__':
    vlandpath = os.path.abspath(os.path.normpath(os.path.dirname(sys.argv[0])))
    sys.path[:0] = ["%s/.." % vlandpath, vlandpath]
| |
| from errors import CriticalError, InputError, NotFoundError |
| |
# Database access layer: wraps a PostgreSQL connection (via psycopg2) and
# provides create / delete / lookup / update helpers for the switch, port,
# vlan and trunk tables, plus the "state" table that records the schema
# version and the last-modified time.
#
# Methods that modify the database follow one pattern: run the statement,
# stamp the state table, commit; on any failure roll back and re-raise.
# The bare "except:" clauses used for that are deliberate - the rollback
# must happen whatever was raised, and the exception is always re-raised.
class VlanDB:
    # Connect to the database and create the two cursors used everywhere
    # else.
    #
    # db_name:  name of the database to connect to
    # username: database user to connect as
    # readonly: if False, also run _init_state() to create / upgrade the
    #           state table and schema. If True (the default), the schema
    #           is left alone.
    #
    # Any failure is logged and then re-raised.
    def __init__(self, db_name="vland", username="vland", readonly=True):
        try:
            self.connection = psycopg2.connect(database=db_name, user=username)
            # Create first cursor for normal usage - returns named tuples
            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
            # Create second cursor for full-row lookups - returns a dict
            # instead, much more useful in the admin interface
            self.dictcursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            if not readonly:
                self._init_state()
        except Exception as e:
            logging.error("Failed to access database: %s", e)
            raise

    # Close the cursors and the connection.
    #
    # __init__ may have failed part-way through (e.g. the connection was
    # refused), in which case some of these attributes were never set;
    # only close what actually exists.
    def __del__(self):
        for attr in ("cursor", "dictcursor", "connection"):
            resource = getattr(self, attr, None)
            if resource is not None:
                resource.close()

    # Create the state table (if needed) and add its only record
    #
    # Use the stored record of the expected database schema to track what
    # version the on-disk database is, and upgrade it to match the current code
    # if necessary.
    def _init_state(self):
        found_db = False
        current_db_version = 0
        try:
            sql = "SELECT * FROM state"
            self.cursor.execute(sql)
            found_db = True
        except psycopg2.ProgrammingError:
            # The state table doesn't exist. The failed SELECT left the
            # transaction aborted; end it so that further statements can
            # be executed.
            self.connection.commit()
            sql = "CREATE TABLE state (last_modified TIMESTAMP, schema_version INTEGER)"
            self.cursor.execute(sql)
            # We've just created a version 1 database
            current_db_version = 1

        if found_db:
            # Grab the version of the database we have
            # NOTE(review): if the state table exists but holds no row,
            # fetchone() returns None and the subscript below raises
            # TypeError - confirm whether that can happen in practice.
            try:
                sql = "SELECT schema_version FROM state"
                self.cursor.execute(sql)
                current_db_version = self.cursor.fetchone()[0]
            # No version found ==> we have "version 0"
            except psycopg2.ProgrammingError:
                # The schema_version column doesn't exist; end the
                # aborted transaction as above
                self.connection.commit()
                current_db_version = 0

            # Now delete the existing state record, we'll write a new one in a
            # moment
            self.cursor.execute('DELETE FROM state')
            logging.info("Found a database, version %d", current_db_version)

        # Apply upgrades here!
        if current_db_version < 1:
            logging.info("Upgrading database to match schema version 1")
            sql = "ALTER TABLE state ADD schema_version INTEGER"
            self.cursor.execute(sql)
            logging.info("Schema version 1 upgrade successful")

        if current_db_version < 2:
            logging.info("Upgrading database to match schema version 2")
            sql = "ALTER TABLE port ADD lock_reason VARCHAR(64)"
            self.cursor.execute(sql)
            logging.info("Schema version 2 upgrade successful")

        sql = "INSERT INTO state (last_modified, schema_version) VALUES (%s, %s)"
        data = (datetime.datetime.now(), DATABASE_SCHEMA_VERSION)
        self.cursor.execute(sql, data)
        self.connection.commit()

    # Internal helper: record the current time as the last-modified time
    # in the state table, then commit the open transaction. Every method
    # that changes the database finishes with this.
    def _commit_modified(self):
        self.cursor.execute("UPDATE state SET last_modified=%s", (datetime.datetime.now(),))
        self.connection.commit()

    # Create a new switch in the database. Switches are really simple
    # devices - they're just containers for ports.
    #
    # Returns the ID of the new switch.
    #
    # Constraints:
    # Switches must be uniquely named (InputError otherwise)
    def create_switch(self, name):

        switch_id = self.get_switch_id_by_name(name)
        if switch_id is not None:
            raise InputError("Switch name %s already exists" % name)

        try:
            sql = "INSERT INTO switch (name) VALUES (%s) RETURNING switch_id"
            data = (name, )
            self.cursor.execute(sql, data)
            switch_id = self.cursor.fetchone()[0]
            self._commit_modified()
        except:
            self.connection.rollback()
            raise

        return switch_id

    # Create a new port in the database. Three of the fields are
    # created with default values (is_locked, is_trunk, trunk_id)
    # here, and should be updated separately if desired. For the
    # current_vlan_id and base_vlan_id fields, *BE CAREFUL* that you
    # have already looked up the correct VLAN_ID for each. This is
    # *NOT* the same as the VLAN tag (likely to be 1). You Have Been
    # Warned!
    #
    # Returns the ID of the new port.
    #
    # Constraints:
    # 1. The switch referred to must already exist (NotFoundError)
    # 2. The VLANs mentioned here must already exist (NotFoundError)
    # 3. (Switch/name) must be unique (InputError)
    # 4. (Switch/number) must be unique (InputError)
    def create_port(self, switch_id, name, number, current_vlan_id, base_vlan_id):

        switch = self.get_switch_by_id(switch_id)
        if switch is None:
            raise NotFoundError("Switch ID %d does not exist" % int(switch_id))

        for vlan_id in (current_vlan_id, base_vlan_id):
            vlan = self.get_vlan_by_id(vlan_id)
            if vlan is None:
                raise NotFoundError("VLAN ID %d does not exist" % int(vlan_id))

        port_id = self.get_port_by_switch_and_name(switch_id, name)
        if port_id is not None:
            raise InputError("Already have a port %s on switch ID %d" % (name, int(switch_id)))

        port_id = self.get_port_by_switch_and_number(switch_id, int(number))
        if port_id is not None:
            raise InputError("Already have a port %d on switch ID %d" % (int(number), int(switch_id)))

        try:
            sql = "INSERT INTO port (name, number, switch_id, is_locked, lock_reason, is_trunk, current_vlan_id, base_vlan_id, trunk_id) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING port_id"
            data = (name, number, switch_id,
                    False, "",
                    False,
                    current_vlan_id, base_vlan_id, TRUNK_ID_NONE)
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self._commit_modified()
        except:
            self.connection.rollback()
            raise

        return port_id

    # Create a new vlan in the database. We locally add a creation
    # timestamp, for debug purposes. If vlans seem to be sticking
    # around, we'll be able to see when they were created.
    #
    # Returns the ID of the new VLAN.
    #
    # Constraints (InputError if violated):
    # Names and tags must be unique
    # Tags must be in the range 1-4095 (802.1q spec)
    # Names can be any free-form text, length 1-32 characters
    def create_vlan(self, name, tag, is_base_vlan):

        if int(tag) < 1 or int(tag) > 4095:
            raise InputError("VLAN tag %d is outside of the valid range (1-4095)" % int(tag))

        if (len(name) < 1) or (len(name) > 32):
            raise InputError("VLAN name %s is invalid (must be 1-32 chars)" % name)

        vlan_id = self.get_vlan_id_by_name(name)
        if vlan_id is not None:
            raise InputError("VLAN name %s is already in use" % name)

        vlan_id = self.get_vlan_id_by_tag(tag)
        if vlan_id is not None:
            raise InputError("VLAN tag %d is already in use" % int(tag))

        try:
            dt = datetime.datetime.now()
            sql = "INSERT INTO vlan (name, tag, is_base_vlan, creation_time) VALUES (%s, %s, %s, %s) RETURNING vlan_id"
            data = (name, tag, is_base_vlan, dt)
            self.cursor.execute(sql, data)
            vlan_id = self.cursor.fetchone()[0]
            self._commit_modified()
        except:
            self.connection.rollback()
            raise

        return vlan_id

    # Create a new trunk in the database, linking two ports. Trunks
    # are really simple objects for our use - they're just containers
    # for 2 ports.
    #
    # Returns the ID of the new trunk.
    #
    # Constraints:
    # 1. Both ports listed must already exist (NotFoundError).
    # 2. Both ports must be in trunk mode (InputError).
    # 3. Both must not be locked (InputError).
    # 4. Both must not already be in a trunk (InputError).
    def create_trunk(self, port_id1, port_id2):

        for port_id in (port_id1, port_id2):
            port = self.get_port_by_id(int(port_id))
            if port is None:
                raise NotFoundError("Port ID %d does not exist" % int(port_id))
            if not port['is_trunk']:
                raise InputError("Port ID %d is not in trunk mode" % int(port_id))
            if port['is_locked']:
                raise InputError("Port ID %d is locked" % int(port_id))
            if port['trunk_id'] != TRUNK_ID_NONE:
                raise InputError("Port ID %d is already on trunk ID %d" % (int(port_id), int(port['trunk_id'])))

        # Stays None until the trunk row really exists, so the cleanup
        # below knows whether there is anything to delete
        trunk_id = None
        try:
            # Add the trunk itself
            dt = datetime.datetime.now()
            sql = "INSERT INTO trunk (creation_time) VALUES (%s) RETURNING trunk_id"
            data = (dt, )
            self.cursor.execute(sql, data)
            trunk_id = self.cursor.fetchone()[0]
            self._commit_modified()
            # And update the ports (each update commits separately)
            for port_id in (port_id1, port_id2):
                self._set_port_trunk(port_id, trunk_id)
        except:
            self._abandon_trunk(trunk_id)
            raise

        return trunk_id

    # Internal helper for create_trunk: undo a half-finished trunk
    # creation. Rolls back whatever is pending and, if the trunk row was
    # already committed, deletes it again (detaching any port that was
    # attached). Problems during the cleanup are logged rather than
    # raised, so that the caller can re-raise the original error.
    def _abandon_trunk(self, trunk_id):
        try:
            self.connection.rollback()
            if trunk_id is not None:
                self.delete_trunk(trunk_id)
        except Exception as e:
            logging.error("Failed to clean up trunk ID %s: %s", trunk_id, e)

    # Internal helper function: delete the rows of <table> where <field>
    # equals <value>. The table and field names are substituted directly
    # into the SQL, so they must only ever come from this code; the value
    # is passed through psycopg2's own parameter handling.
    def _delete_row(self, table, field, value):
        try:
            sql = "DELETE FROM %s WHERE %s = %s" % (table, field, '%s')
            data = (value,)
            self.cursor.execute(sql, data)
            self._commit_modified()
        except:
            self.connection.rollback()
            raise

    # Delete the specified switch
    #
    # Returns the ID of the deleted switch.
    #
    # Constraints:
    # 1. The switch must exist (NotFoundError)
    # 2. The switch may not be referenced by any ports -
    #    delete them first! (InputError)
    def delete_switch(self, switch_id):
        switch = self.get_switch_by_id(switch_id)
        if switch is None:
            raise NotFoundError("Switch ID %d does not exist" % int(switch_id))
        ports = self.get_ports_by_switch(switch_id)
        if ports is not None:
            raise InputError("Cannot delete switch ID %d when it still has %d ports" %
                             (int(switch_id), len(ports)))
        self._delete_row("switch", "switch_id", switch_id)
        return switch_id

    # Delete the specified port
    #
    # Returns the ID of the deleted port.
    #
    # Constraints:
    # 1. The port must exist (NotFoundError)
    # 2. The port must not be locked (InputError)
    def delete_port(self, port_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise NotFoundError("Port ID %d does not exist" % int(port_id))
        if port['is_locked']:
            raise InputError("Cannot delete port ID %d as it is locked" % int(port_id))
        self._delete_row("port", "port_id", port_id)
        return port_id

    # Delete the specified VLAN
    #
    # Returns the ID of the deleted VLAN.
    #
    # Constraints:
    # 1. The VLAN must exist (NotFoundError)
    # 2. The VLAN may not contain any ports, either as their current or
    #    as their base VLAN - move or delete them first! (InputError)
    def delete_vlan(self, vlan_id):
        vlan = self.get_vlan_by_id(vlan_id)
        if vlan is None:
            raise NotFoundError("VLAN ID %d does not exist" % int(vlan_id))
        ports = self.get_ports_by_current_vlan(vlan_id)
        if ports is not None:
            raise InputError("Cannot delete VLAN ID %d when it still has %d ports" %
                             (int(vlan_id), len(ports)))
        ports = self.get_ports_by_base_vlan(vlan_id)
        if ports is not None:
            raise InputError("Cannot delete VLAN ID %d when it still has %d ports" %
                             (int(vlan_id), len(ports)))
        self._delete_row("vlan", "vlan_id", vlan_id)
        return vlan_id

    # Delete the specified trunk
    #
    # Returns the ID of the deleted trunk.
    #
    # Constraints:
    # 1. The trunk must exist (NotFoundError)
    #
    # Any ports attached will be detached (i.e. moved to trunk TRUNK_ID_NONE)
    def delete_trunk(self, trunk_id):
        trunk = self.get_trunk_by_id(trunk_id)
        if trunk is None:
            raise NotFoundError("Trunk ID %d does not exist" % int(trunk_id))
        # The lookup returns None (not an empty list) when no ports are
        # attached to the trunk
        ports = self.get_ports_by_trunk(trunk_id) or []
        for port_id in ports:
            self._set_port_trunk(port_id, TRUNK_ID_NONE)
        self._delete_row("trunk", "trunk_id", trunk_id)
        return trunk_id

    # Find the lowest unused VLAN tag and return it
    #
    # The search starts after tag 1, so tag 1 itself is never returned.
    # Raises CriticalError if the result would be above 4093.
    #
    # Constraints:
    # None
    def find_lowest_unused_vlan_tag(self):
        sql = "SELECT tag FROM vlan ORDER BY tag ASC"
        self.cursor.execute(sql)

        # Walk through the tags in ascending order, looking for the
        # first gap
        last = 1
        result = None

        for record in self.cursor:
            if (record[0] - last) > 1:
                result = last + 1
                break
            last = record[0]

        # No gap found: use the tag after the highest one seen
        if result is None:
            result = last + 1

        if result > 4093:
            raise CriticalError("Can't find any VLAN tags remaining for allocation!")

        return result

    # Internal helper: build and execute, on the normal cursor,
    #
    #   SELECT <select_field> FROM <table>
    #          WHERE <field> = <value> [AND <field> = <value> ...]
    #          [ORDER BY <sort_field> ASC]
    #
    # conditions is a sequence of (field, value) pairs. A value of None
    # is rejected with ValueError.
    #
    # We really want to use psycopg's type handling to deal with the
    # (potentially) user-supplied data in the values, so they are
    # passed through to cursor.execute as a tuple of parameters.
    # However, we can't have psycopg do all the argument substitution
    # as it would quote the table and field names too, and that
    # doesn't work. So those names are substituted with python's own
    # string formatting, leaving one "%s" placeholder per value for
    # psycopg to fill in.
    def _select(self, select_field, table, conditions, sort_field=None):
        for field, value in conditions:
            if value is None:
                raise ValueError("Asked to look up using None as a key in %s" % field)

        where = " AND ".join("%s = %%s" % field for field, _ in conditions)
        sql = "SELECT %s FROM %s WHERE %s" % (select_field, table, where)
        if sort_field is not None:
            sql += " ORDER BY %s ASC" % sort_field

        data = tuple(value for _, value in conditions)
        self.cursor.execute(sql, data)

    # Internal helper: first column of the first row of the query just
    # run on the normal cursor, or None if it matched nothing
    def _first_element(self):
        if self.cursor.rowcount > 0:
            return self.cursor.fetchone()[0]
        return None

    # Internal helper: list holding the first column of every row of the
    # query just run on the normal cursor, or None if it matched nothing
    def _all_elements(self):
        if self.cursor.rowcount > 0:
            return [record[0] for record in self.cursor]
        return None

    # Grab one column from one row of a query on one column; useful as
    # a quick wrapper
    #
    # Returns None if nothing matches; raises ValueError if value is None.
    def _get_element(self, select_field, table, compare_field, value):
        self._select(select_field, table, [(compare_field, value)])
        return self._first_element()

    # Grab one column from one row of a query on 2 columns; useful as
    # a quick wrapper
    #
    # Returns None if nothing matches; raises ValueError if either value
    # is None.
    def _get_element2(self, select_field, table, compare_field1, value1, compare_field2, value2):
        self._select(select_field, table,
                     [(compare_field1, value1), (compare_field2, value2)])
        return self._first_element()

    # Grab one column from multiple rows of a query; useful as a quick
    # wrapper
    #
    # Returns a list sorted by sort_field, or None (not an empty list) if
    # nothing matches; raises ValueError if value is None.
    def _get_multi_elements(self, select_field, table, compare_field, value, sort_field):
        self._select(select_field, table, [(compare_field, value)], sort_field)
        return self._all_elements()

    # Grab one column from multiple rows of a 2-part query; useful as
    # a wrapper
    #
    # Returns a list sorted by sort_field, or None (not an empty list) if
    # nothing matches; raises ValueError if either value is None.
    def _get_multi_elements2(self, select_field, table, compare_field1, value1, compare_field2, value2, sort_field):
        self._select(select_field, table,
                     [(compare_field1, value1), (compare_field2, value2)],
                     sort_field)
        return self._all_elements()

    # Simple lookup: look up a switch by ID, and return all the
    # details of that switch.
    #
    # Returns None on failure.
    def get_switch_by_id(self, switch_id):
        return self._get_row("switch", "switch_id", int(switch_id))

    # Simple lookup: look up a switch by name, and return the ID of
    # that switch.
    #
    # Returns None on failure.
    def get_switch_id_by_name(self, name):
        return self._get_element("switch_id", "switch", "name", name)

    # Simple lookup: look up a switch by ID, and return the name of
    # that switch.
    #
    # Returns None on failure.
    def get_switch_name_by_id(self, switch_id):
        return self._get_element("name", "switch", "switch_id", int(switch_id))

    # Simple lookup: look up a port by ID, and return all the details
    # of that port.
    #
    # Returns None on failure.
    def get_port_by_id(self, port_id):
        return self._get_row("port", "port_id", int(port_id))

    # Simple lookup: look up a switch by ID, and return the IDs of all
    # the ports on that switch.
    #
    # Returns None on failure.
    def get_ports_by_switch(self, switch_id):
        return self._get_multi_elements("port_id", "port", "switch_id", int(switch_id), "port_id")

    # More complex lookup: look up the names of all the trunk ports on
    # a switch by ID
    #
    # Returns None on failure.
    def get_trunk_port_names_by_switch(self, switch_id):
        return self._get_multi_elements2("name", "port", "switch_id", int(switch_id), "is_trunk", True, "port_id")

    # Simple lookup: look up a port by its name and its parent switch
    # by ID, and return the ID of the port.
    #
    # Returns None on failure.
    def get_port_by_switch_and_name(self, switch_id, name):
        return self._get_element2("port_id", "port", "switch_id", int(switch_id), "name", name)

    # Simple lookup: look up a port by its number and its parent
    # switch by ID, and return the ID of the port.
    #
    # Returns None on failure.
    def get_port_by_switch_and_number(self, switch_id, number):
        return self._get_element2("port_id", "port", "switch_id", int(switch_id), "number", int(number))

    # Simple lookup: look up a port by ID, and return the current VLAN
    # id of that port.
    #
    # Returns None on failure.
    def get_current_vlan_id_by_port(self, port_id):
        return self._get_element("current_vlan_id", "port", "port_id", int(port_id))

    # Simple lookup: look up a port by ID, and return the base VLAN
    # id of that port.
    #
    # Returns None on failure.
    def get_base_vlan_id_by_port(self, port_id):
        return self._get_element("base_vlan_id", "port", "port_id", int(port_id))

    # Simple lookup: look up a current VLAN by ID, and return the IDs
    # of all the ports on that VLAN.
    #
    # Returns None on failure.
    def get_ports_by_current_vlan(self, vlan_id):
        return self._get_multi_elements("port_id", "port", "current_vlan_id", int(vlan_id), "port_id")

    # Simple lookup: look up a base VLAN by ID, and return the IDs
    # of all the ports on that VLAN.
    #
    # Returns None on failure.
    def get_ports_by_base_vlan(self, vlan_id):
        return self._get_multi_elements("port_id", "port", "base_vlan_id", int(vlan_id), "port_id")

    # Simple lookup: look up a trunk by ID, and return the IDs of the
    # ports on both ends of that trunk.
    #
    # Returns None on failure (including when no ports are attached).
    def get_ports_by_trunk(self, trunk_id):
        return self._get_multi_elements("port_id", "port", "trunk_id", int(trunk_id), "port_id")

    # Simple lookup: look up a VLAN by ID, and return all the details
    # of that VLAN.
    #
    # Returns None on failure.
    def get_vlan_by_id(self, vlan_id):
        return self._get_row("vlan", "vlan_id", int(vlan_id))

    # Simple lookup: look up a VLAN by name, and return the ID of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_id_by_name(self, name):
        return self._get_element("vlan_id", "vlan", "name", name)

    # Simple lookup: look up a VLAN by tag, and return the ID of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_id_by_tag(self, tag):
        return self._get_element("vlan_id", "vlan", "tag", int(tag))

    # Simple lookup: look up a VLAN by ID, and return the name of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_name_by_id(self, vlan_id):
        return self._get_element("name", "vlan", "vlan_id", int(vlan_id))

    # Simple lookup: look up a VLAN by ID, and return the tag of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_tag_by_id(self, vlan_id):
        return self._get_element("tag", "vlan", "vlan_id", int(vlan_id))

    # Simple lookup: look up a trunk by ID, and return all the details
    # of that trunk.
    #
    # Returns None on failure.
    def get_trunk_by_id(self, trunk_id):
        return self._get_row("trunk", "trunk_id", int(trunk_id))

    # Get the last-modified time for the database
    def get_last_modified_time(self):
        sql = "SELECT last_modified FROM state"
        self.cursor.execute(sql)
        return self.cursor.fetchone()[0]

    # Grab one row of a query on one column; useful as a quick wrapper
    #
    # Uses the dict cursor, so the row comes back as a dict keyed by
    # column name. Returns None if nothing matches.
    def _get_row(self, table, field, value):

        # The table and field names are substituted here with python's
        # own string formatting, leaving a "%s" placeholder for the
        # value so that psycopg's type handling deals with the
        # (potentially) user-supplied data.
        sql = "SELECT * FROM %s WHERE %s = %s" % (table, field, "%s")

        # psycopg2 needs the parameters as a sequence; the trailing
        # comma makes this a one-element tuple
        data = (value, )
        self.dictcursor.execute(sql, data)
        return self.dictcursor.fetchone()

    # (Un)Lock a port in the database. This can only be done through
    # the admin interface, and will stop API users from modifying
    # settings on the port. Use this to lock down ports that are used
    # for PDUs and other core infrastructure
    #
    # Returns the ID of the port.
    #
    # Raises NotFoundError if the port does not exist, and InputError
    # if the database update fails (the underlying error is logged).
    def set_port_is_locked(self, port_id, is_locked, lock_reason=""):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise NotFoundError("Port ID %d does not exist" % int(port_id))
        try:
            sql = "UPDATE port SET is_locked=%s, lock_reason=%s WHERE port_id=%s RETURNING port_id"
            data = (is_locked, lock_reason, port_id)
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self._commit_modified()
        except Exception as e:
            self.connection.rollback()
            # Keep a record of what actually went wrong before it is
            # replaced by the InputError that callers expect
            logging.error("Failed to update lock on port ID %s: %s", port_id, e)
            raise InputError("lock failed on Port ID %d" % int(port_id))
        except:
            # Interrupts and the like: still roll back, but let them
            # through unchanged rather than disguising them
            self.connection.rollback()
            raise
        return port_id

    # Set the mode of a port in the database. Valid values for mode
    # are "trunk" and "access"
    #
    # Returns the ID of the port.
    #
    # Raises NotFoundError if the port does not exist, and InputError
    # for any other mode.
    def set_port_mode(self, port_id, mode):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise NotFoundError("Port ID %d does not exist" % int(port_id))
        if mode == "access":
            is_trunk = False
        elif mode == "trunk":
            is_trunk = True
        else:
            raise InputError("Port mode %s is not valid" % mode)
        try:
            sql = "UPDATE port SET is_trunk=%s WHERE port_id=%s RETURNING port_id"
            data = (is_trunk, port_id)
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self._commit_modified()
        except:
            self.connection.rollback()
            raise
        return port_id

    # Set the current vlan of a port in the database. The VLAN is
    # passed by ID.
    #
    # Returns the ID of the port.
    #
    # Constraints:
    # 1. The port must already exist (NotFoundError)
    # 2. The port must not be a trunk port (CriticalError)
    # 3. The port must not be locked (CriticalError)
    # 4. The VLAN must already exist in the database (NotFoundError)
    def set_current_vlan(self, port_id, vlan_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise NotFoundError("Port ID %d does not exist" % int(port_id))

        if port['is_locked']:
            raise CriticalError("The port is locked")
        if port['is_trunk']:
            raise CriticalError("The port is in trunk mode")

        vlan = self.get_vlan_by_id(vlan_id)
        if vlan is None:
            raise NotFoundError("VLAN ID %d does not exist" % int(vlan_id))

        try:
            sql = "UPDATE port SET current_vlan_id=%s WHERE port_id=%s RETURNING port_id"
            data = (vlan_id, port_id)
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self._commit_modified()
        except:
            self.connection.rollback()
            raise
        return port_id

    # Set the base vlan of a port in the database. The VLAN is
    # passed by ID.
    #
    # Returns the ID of the port.
    #
    # Constraints:
    # 1. The port must already exist (NotFoundError)
    # 2. The port must not be a trunk port (CriticalError)
    # 3. The port must not be locked (CriticalError)
    # 4. The VLAN must already exist in the database (NotFoundError)
    # 5. The VLAN must be a base VLAN (InputError)
    def set_base_vlan(self, port_id, vlan_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise NotFoundError("Port ID %d does not exist" % int(port_id))

        if port['is_locked']:
            raise CriticalError("The port is locked")
        if port['is_trunk']:
            raise CriticalError("The port is in trunk mode")

        vlan = self.get_vlan_by_id(vlan_id)
        if vlan is None:
            raise NotFoundError("VLAN ID %d does not exist" % int(vlan_id))
        if not vlan['is_base_vlan']:
            raise InputError("VLAN ID %d is not a base VLAN" % int(vlan_id))

        try:
            sql = "UPDATE port SET base_vlan_id=%s WHERE port_id=%s RETURNING port_id"
            data = (vlan_id, port_id)
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self._commit_modified()
        except:
            self.connection.rollback()
            raise
        return port_id

    # Internal function: Attach a port to a trunk in the database.
    #
    # Returns the ID of the port.
    #
    # Constraints:
    # 1. The port must already exist (NotFoundError)
    # 2. The port must not be locked (CriticalError)
    def _set_port_trunk(self, port_id, trunk_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise NotFoundError("Port ID %d does not exist" % int(port_id))
        if port['is_locked']:
            raise CriticalError("The port is locked")
        try:
            sql = "UPDATE port SET trunk_id=%s WHERE port_id=%s RETURNING port_id"
            data = (int(trunk_id), int(port_id))
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self._commit_modified()
        except:
            self.connection.rollback()
            raise
        return port_id

    # Trivial helper function to return all the rows in a given table,
    # as a list of dicts sorted by the <order> column. The table and
    # column names are substituted directly into the SQL, so they must
    # only ever come from this code.
    def _dump_table(self, table, order):
        self.dictcursor.execute("SELECT * FROM %s ORDER by %s ASC" % (table, order))
        return list(self.dictcursor.fetchall())

    # All the switches in the database, sorted by switch ID
    def all_switches(self):
        return self._dump_table("switch", "switch_id")

    # All the ports in the database, sorted by port ID
    def all_ports(self):
        return self._dump_table("port", "port_id")

    # All the VLANs in the database, sorted by VLAN ID
    def all_vlans(self):
        return self._dump_table("vlan", "vlan_id")

    # All the trunks in the database, sorted by trunk ID
    def all_trunks(self):
        return self._dump_table("trunk", "trunk_id")
| |
# When run directly as a script: connect with the default settings, print
# the contents of each table, then report the first free VLAN tag.
#
# Each print call takes a single argument, so it produces the same output
# whether print is a statement (Python 2) or a function (Python 3).
if __name__ == '__main__':
    db = VlanDB()
    s = db.all_switches()
    print('The DB knows about %d switch(es)' % len(s))
    print(s)
    p = db.all_ports()
    print('The DB knows about %d port(s)' % len(p))
    print(p)
    v = db.all_vlans()
    print('The DB knows about %d vlan(s)' % len(v))
    print(v)
    t = db.all_trunks()
    print('The DB knows about %d trunk(s)' % len(t))
    print(t)

    print('First free VLAN tag is %d' % db.find_lowest_unused_vlan_tag())