#! /usr/bin/python

#  Copyright 2014-2015 Linaro Limited
#  Authors: Dave Pigott,
#           Steve McIntyre
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software
#  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
#  MA 02110-1301, USA.

import datetime
import logging
import os
import sys

import psycopg2
import psycopg2.extras

# Value stored in port.trunk_id for a port that is not part of any trunk
TRUNK_ID_NONE = -1

# When run directly as a script, make sure the directory holding this
# file (and its parent) are on the module search path so that the
# "errors" import below can be resolved.
if __name__ == '__main__':
    vlandpath = os.path.abspath(os.path.normpath(os.path.dirname(sys.argv[0])))
    sys.path.insert(0, vlandpath)
    sys.path.insert(0, "%s/.." % vlandpath)

from errors import CriticalError, InputError


# Thin wrapper around the PostgreSQL database holding the switch, port,
# vlan and trunk tables. Every modifying method commits on success and
# rolls back (then re-raises) on failure.
class VlanDB:
    def __init__(self, db_name="vland", username="vland"):
        # Set these first so that __del__ is safe even if connecting
        # fails part-way through
        self.connection = None
        self.cursor = None
        try:
            self.connection = psycopg2.connect(database=db_name, user=username)
            self.cursor = self.connection.cursor(
                cursor_factory=psycopg2.extras.NamedTupleCursor)
        except Exception as e:
            logging.error("Failed to access database: %s", e)
            raise

    def __del__(self):
        # Either attribute may be missing/None if __init__ failed, so
        # only close what we actually managed to open
        cursor = getattr(self, 'cursor', None)
        if cursor is not None:
            cursor.close()
        connection = getattr(self, 'connection', None)
        if connection is not None:
            connection.close()

    # Internal helper function: run a modifying statement that ends in
    # a RETURNING clause, commit, and return the first column of the
    # returned row. On any failure the transaction is rolled back and
    # the exception is re-raised unchanged.
    def _execute_returning(self, sql, data):
        try:
            self.cursor.execute(sql, data)
            result = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            # Deliberately broad: we always re-raise, and we never want
            # to leave a half-finished transaction open on the connection
            self.connection.rollback()
            raise
        return result

    # Create a new switch in the database. Switches are really simple
    # devices - they're just containers for ports.
    #
    # Constraints:
    # Switches must be uniquely named
    #
    # Returns the ID of the new switch; raises InputError if the name
    # is already in use.
    def create_switch(self, name):
        if self.get_switch_id_by_name(name) is not None:
            raise InputError("Switch name %s already exists" % name)

        sql = "INSERT INTO switch (name) VALUES (%s) RETURNING switch_id"
        return self._execute_returning(sql, (name, ))

    # Create a new port in the database. Three of the fields are
    # created with default values (is_locked, is_trunk, trunk_id)
    # here, and should be updated separately if desired. For the
    # current_vlan_id and base_vlan_id fields, *BE CAREFUL* that you
    # have already looked up the correct VLAN_ID for each. This is
    # *NOT* the same as the VLAN tag (likely to be 1). You Have Been
    # Warned!
    #
    # Constraints:
    # 1. The switch referred to must already exist
    # 2. The VLANs mentioned here must already exist
    # 3. (Switch/name) must be unique
    # 4. (Switch/number) must be unique
    #
    # Returns the ID of the new port; raises InputError if any of the
    # constraints is not met.
    def create_port(self, switch_id, name, number, current_vlan_id, base_vlan_id):
        if self.get_switch_by_id(switch_id) is None:
            raise InputError("Switch ID %d does not exist" % int(switch_id))

        for vlan_id in (current_vlan_id, base_vlan_id):
            if self.get_vlan_by_id(vlan_id) is None:
                raise InputError("VLAN ID %d does not exist" % int(vlan_id))

        if self.get_port_by_switch_and_name(switch_id, name) is not None:
            raise InputError("Already have a port %s on switch ID %d"
                             % (name, int(switch_id)))

        if self.get_port_by_switch_and_number(switch_id, int(number)) is not None:
            raise InputError("Already have a port %d on switch ID %d"
                             % (int(number), int(switch_id)))

        sql = "INSERT INTO port (name, number, switch_id, is_locked, is_trunk, current_vlan_id, base_vlan_id, trunk_id) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) RETURNING port_id"
        data = (name, number, switch_id,
                False, False,
                current_vlan_id, base_vlan_id, TRUNK_ID_NONE)
        return self._execute_returning(sql, data)

    # Create a new vlan in the database. We locally add a creation
    # timestamp, for debug purposes. If vlans seems to be sticking
    # around, we'll be able to see when they were created.
    #
    # Constraints:
    # Names and tags must be unique
    # Tags must be in the range 1-4095 (802.1q spec)
    # Names can be any free-form text, length 1-32 characters
    #
    # Returns the ID of the new VLAN; raises InputError if any of the
    # constraints is not met.
    def create_vlan(self, name, tag, is_base_vlan):
        if int(tag) < 1 or int(tag) > 4095:
            raise InputError("VLAN tag %d is outside of the valid range (1-4095)" % int(tag))

        if (len(name) < 1) or (len(name) > 32):
            raise InputError("VLAN name %s is invalid (must be 1-32 chars)" % name)

        if self.get_vlan_id_by_name(name) is not None:
            raise InputError("VLAN name %s is already in use" % name)

        if self.get_vlan_id_by_tag(tag) is not None:
            raise InputError("VLAN tag %d is already in use" % int(tag))

        sql = "INSERT INTO vlan (name, tag, is_base_vlan, creation_time) VALUES (%s, %s, %s, %s) RETURNING vlan_id"
        data = (name, tag, is_base_vlan, datetime.datetime.now())
        return self._execute_returning(sql, data)

    # Create a new trunk in the database, linking two ports. Trunks
    # are really simple objects for our use - they're just containers
    # for 2 ports.
    #
    # Constraints:
    # 1. Both ports listed must already exist.
    # 2. Both ports must be in trunk mode.
    # 3. Both must not be locked.
    # 4. Both must not already be in a trunk.
    #
    # Returns the ID of the new trunk; raises InputError if any of the
    # constraints is not met.
    def create_trunk(self, port_id1, port_id2):
        for port_id in (port_id1, port_id2):
            port = self.get_port_by_id(int(port_id))
            if port is None:
                raise InputError("Port ID %d does not exist" % int(port_id))
            if not port.is_trunk:
                raise InputError("Port ID %d is not in trunk mode" % int(port_id))
            if port.is_locked:
                raise InputError("Port ID %d is locked" % int(port_id))
            if port.trunk_id != TRUNK_ID_NONE:
                raise InputError("Port ID %d is already on trunk ID %d"
                                 % (int(port_id), int(port.trunk_id)))

        # Stays None until the trunk row really exists, so the cleanup
        # below knows whether there is anything to undo
        trunk_id = None
        try:
            # Add the trunk itself
            sql = "INSERT INTO trunk (creation_time) VALUES (%s) RETURNING trunk_id"
            trunk_id = self._execute_returning(sql, (datetime.datetime.now(), ))
            # And update the ports
            for port_id in (port_id1, port_id2):
                self._set_port_trunk(port_id, trunk_id)
        except:
            # The trunk row has already been committed by the time we
            # attach ports, so remove it again if attaching failed. If
            # the INSERT itself failed it has been rolled back already
            # and there is nothing to delete.
            if trunk_id is not None:
                self.delete_trunk(trunk_id)
            raise

        return trunk_id

    # Internal helper function: delete all rows of a table where field
    # matches value. The table and field names are substituted into
    # the SQL directly, so they must only ever come from our own code;
    # the value is passed through psycopg2's parameter handling.
    def _delete_row(self, table, field, value):
        try:
            sql = "DELETE FROM %s WHERE %s = %s" % (table, field, '%s')
            self.cursor.execute(sql, (value, ))
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

    # Delete the specified switch
    #
    # Constraints:
    # 1. The switch must exist
    # 2. The switch may not be referenced by any ports -
    #    delete them first!
    #
    # Returns the ID of the deleted switch.
    def delete_switch(self, switch_id):
        if self.get_switch_by_id(switch_id) is None:
            raise InputError("Switch ID %d does not exist" % int(switch_id))

        # The lookup returns None (not an empty list) when no port matches
        ports = self.get_ports_by_switch(switch_id)
        if ports is not None:
            raise InputError("Cannot delete switch ID %d when it still has %d ports"
                             % (int(switch_id), len(ports)))

        self._delete_row("switch", "switch_id", switch_id)
        return switch_id

    # Delete the specified port
    #
    # Constraints:
    # 1. The port must exist
    # 2. The port must not be locked
    #
    # Returns the ID of the deleted port.
    def delete_port(self, port_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise InputError("Port ID %d does not exist" % int(port_id))
        if port.is_locked:
            raise InputError("Cannot delete port ID %d as it is locked" % int(port_id))

        self._delete_row("port", "port_id", port_id)
        return port_id

    # Delete the specified VLAN
    #
    # Constraints:
    # 1. The VLAN must exist
    # 2. The VLAN may not contain any ports - move or delete them first!
    #
    # Returns the ID of the deleted VLAN.
    def delete_vlan(self, vlan_id):
        if self.get_vlan_by_id(vlan_id) is None:
            raise InputError("VLAN ID %d does not exist" % int(vlan_id))

        # A VLAN is "in use" if it is either the current or the base
        # VLAN of any port
        for lookup in (self.get_ports_by_current_vlan, self.get_ports_by_base_vlan):
            ports = lookup(vlan_id)
            if ports is not None:
                raise InputError("Cannot delete VLAN ID %d when it still has %d ports"
                                 % (int(vlan_id), len(ports)))

        self._delete_row("vlan", "vlan_id", vlan_id)
        return vlan_id

    # Delete the specified trunk
    #
    # Constraints:
    # 1. The trunk must exist
    #
    # Any ports attached will be detached (i.e. moved to trunk
    # TRUNK_ID_NONE)
    #
    # Returns the ID of the deleted trunk.
    def delete_trunk(self, trunk_id):
        if self.get_trunk_by_id(trunk_id) is None:
            raise InputError("Trunk ID %d does not exist" % int(trunk_id))

        # The lookup returns None when the trunk has no ports attached
        # (e.g. when cleaning up after a failed create_trunk), so fall
        # back to an empty list rather than trying to iterate over None
        ports = self.get_ports_by_trunk(trunk_id) or []
        for port_id in ports:
            self._set_port_trunk(port_id, TRUNK_ID_NONE)

        self._delete_row("trunk", "trunk_id", trunk_id)
        return trunk_id

    # Find the lowest unused VLAN tag and return it
    #
    # The search starts from tag 1, so the lowest tag that can ever be
    # returned is 2. Raises CriticalError if the candidate found is
    # above 4093.
    #
    # Constraints:
    # None
    def find_lowest_unused_vlan_tag(self):
        sql = "SELECT tag FROM vlan ORDER BY tag ASC"
        self.cursor.execute(sql,)

        # Walk through the sorted list, looking for the first gap
        last = 1
        result = None
        for record in self.cursor:
            if (record[0] - last) > 1:
                result = last + 1
                break
            last = record[0]

        # No gap found: use the tag just past the highest one seen
        if result is None:
            result = last + 1

        if result > 4093:
            raise CriticalError("Can't find any VLAN tags remaining for allocation!")

        return result

    # Internal helper function: build and execute
    #   SELECT <select_field> FROM <table> WHERE <f1> = %s [AND <f2> = %s ...]
    # for the given (field, value) pairs, leaving the results on
    # self.cursor.
    #
    # We really want to use psycopg's type handling to deal with the
    # (potentially) user-supplied data in the values, so we have to
    # pass (sql, data) through to cursor.execute. However, we can't
    # have psycopg do all the argument substitution as it would quote
    # the table and column names too, and that doesn't work. So the
    # names are substituted by python here, leaving a literal "%s"
    # placeholder per value for psycopg to fill in. The table and
    # field names must therefore only ever come from our own code.
    def _select(self, select_field, table, *conditions):
        where = " AND ".join(["%s = %%s" % field for (field, _) in conditions])
        sql = "SELECT %s FROM %s WHERE %s" % (select_field, table, where)
        # psycopg2 wants a sequence of parameters, even for one value
        data = tuple([value for (_, value) in conditions])
        self.cursor.execute(sql, data)

    # Internal helper function: first column of the first row of the
    # query just executed, or None if it matched nothing
    def _fetch_first_element(self):
        if self.cursor.rowcount > 0:
            return self.cursor.fetchone()[0]
        return None

    # Internal helper function: first column of every row of the query
    # just executed, or None (*not* an empty list) if it matched
    # nothing
    def _fetch_all_elements(self):
        if self.cursor.rowcount > 0:
            return [record[0] for record in self.cursor]
        return None

    # Grab one column from one row of a query on one column; useful as
    # a quick wrapper
    def _get_element(self, select_field, table, compare_field, value):
        self._select(select_field, table, (compare_field, value))
        return self._fetch_first_element()

    # Grab one column from one row of a query on 2 columns; useful as
    # a quick wrapper
    def _get_element2(self, select_field, table, compare_field1, value1, compare_field2, value2):
        self._select(select_field, table,
                     (compare_field1, value1), (compare_field2, value2))
        return self._fetch_first_element()

    # Grab one column from multiple rows of a query; useful as a quick
    # wrapper
    def _get_multi_elements(self, select_field, table, compare_field, value):
        self._select(select_field, table, (compare_field, value))
        return self._fetch_all_elements()

    # Grab one column from multiple rows of a 2-part query; useful as
    # a wrapper
    def _get_multi_elements2(self, select_field, table, compare_field1, value1, compare_field2, value2):
        self._select(select_field, table,
                     (compare_field1, value1), (compare_field2, value2))
        return self._fetch_all_elements()

    # Simple lookup: look up a switch by ID, and return all the
    # details of that switch.
    #
    # Returns None on failure.
    def get_switch_by_id(self, switch_id):
        return self._get_row("switch", "switch_id", int(switch_id))

    # Simple lookup: look up a switch by name, and return the ID of
    # that switch.
    #
    # Returns None on failure.
    def get_switch_id_by_name(self, name):
        return self._get_element("switch_id", "switch", "name", name)

    # Simple lookup: look up a switch by ID, and return the name of
    # that switch.
    #
    # Returns None on failure.
    def get_switch_name_by_id(self, switch_id):
        return self._get_element("name", "switch", "switch_id", int(switch_id))

    # Simple lookup: look up a port by ID, and return all the details
    # of that port.
    #
    # Returns None on failure.
    def get_port_by_id(self, port_id):
        return self._get_row("port", "port_id", int(port_id))

    # Simple lookup: look up a switch by ID, and return the IDs of all
    # the ports on that switch.
    #
    # Returns None on failure.
    def get_ports_by_switch(self, switch_id):
        return self._get_multi_elements("port_id", "port", "switch_id", int(switch_id))

    # More complex lookup: look up all the trunk ports on a switch by
    # ID, and return their names.
    #
    # Returns None on failure.
    def get_trunk_port_names_by_switch(self, switch_id):
        return self._get_multi_elements2("name", "port", "switch_id", int(switch_id), "is_trunk", True)

    # Simple lookup: look up a port by its name and its parent switch
    # by ID, and return the ID of the port.
    #
    # Returns None on failure.
    def get_port_by_switch_and_name(self, switch_id, name):
        return self._get_element2("port_id", "port", "switch_id", int(switch_id), "name", name)

    # Simple lookup: look up a port by its number and its parent
    # switch by ID, and return the ID of the port.
    #
    # Returns None on failure.
    def get_port_by_switch_and_number(self, switch_id, number):
        return self._get_element2("port_id", "port", "switch_id", int(switch_id), "number", int(number))

    # Simple lookup: look up a port by ID, and return the current VLAN
    # id of that port.
    #
    # Returns None on failure.
    def get_current_vlan_id_by_port(self, port_id):
        return self._get_element("current_vlan_id", "port", "port_id", int(port_id))

    # Simple lookup: look up a port by ID, and return the base VLAN
    # id of that port.
    #
    # Returns None on failure.
    def get_base_vlan_id_by_port(self, port_id):
        return self._get_element("base_vlan_id", "port", "port_id", int(port_id))

    # Simple lookup: look up a current VLAN by ID, and return the IDs
    # of all the ports on that VLAN.
    #
    # Returns None on failure.
    def get_ports_by_current_vlan(self, vlan_id):
        return self._get_multi_elements("port_id", "port", "current_vlan_id", int(vlan_id))

    # Simple lookup: look up a base VLAN by ID, and return the IDs
    # of all the ports on that VLAN.
    #
    # Returns None on failure.
    def get_ports_by_base_vlan(self, vlan_id):
        return self._get_multi_elements("port_id", "port", "base_vlan_id", int(vlan_id))

    # Simple lookup: look up a trunk by ID, and return the IDs of the
    # ports on both ends of that trunk.
    #
    # Returns None on failure.
    def get_ports_by_trunk(self, trunk_id):
        return self._get_multi_elements("port_id", "port", "trunk_id", int(trunk_id))

    # Simple lookup: look up a VLAN by ID, and return all the details
    # of that VLAN.
    #
    # Returns None on failure.
    def get_vlan_by_id(self, vlan_id):
        return self._get_row("vlan", "vlan_id", int(vlan_id))

    # Simple lookup: look up a VLAN by name, and return the ID of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_id_by_name(self, name):
        return self._get_element("vlan_id", "vlan", "name", name)

    # Simple lookup: look up a VLAN by tag, and return the ID of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_id_by_tag(self, tag):
        return self._get_element("vlan_id", "vlan", "tag", int(tag))

    # Simple lookup: look up a VLAN by ID, and return the name of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_name_by_id(self, vlan_id):
        return self._get_element("name", "vlan", "vlan_id", int(vlan_id))

    # Simple lookup: look up a VLAN by ID, and return the tag of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_tag_by_id(self, vlan_id):
        return self._get_element("tag", "vlan", "vlan_id", int(vlan_id))

    # Simple lookup: look up a trunk by ID, and return all the details
    # of that trunk.
    #
    # Returns None on failure.
    def get_trunk_by_id(self, trunk_id):
        return self._get_row("trunk", "trunk_id", int(trunk_id))

    # Grab one row of a query on one column; useful as a quick wrapper.
    # Returns whatever cursor.fetchone() gives: the first matching row,
    # or None if nothing matched.
    def _get_row(self, table, field, value):
        self._select("*", table, (field, value))
        return self.cursor.fetchone()

    # (Un)Lock a port in the database. This can only be done through
    # the admin interface, and will stop API users from modifying
    # settings on the port. Use this to lock down ports that are used
    # for PDUs and other core infrastructure
    #
    # Returns the ID of the updated port.
    def set_port_is_locked(self, port_id, is_locked):
        if self.get_port_by_id(port_id) is None:
            raise InputError("Port ID %d does not exist" % int(port_id))

        sql = "UPDATE port SET is_locked=%s WHERE port_id=%s RETURNING port_id"
        return self._execute_returning(sql, (is_locked, port_id))

    # Set the mode of a port in the database. Valid values for mode
    # are "trunk" and "access"
    #
    # Returns the ID of the updated port.
    def set_port_mode(self, port_id, mode):
        if self.get_port_by_id(port_id) is None:
            raise InputError("Port ID %d does not exist" % int(port_id))

        if mode == "access":
            is_trunk = False
        elif mode == "trunk":
            is_trunk = True
        else:
            raise InputError("Port mode %s is not valid" % mode)

        sql = "UPDATE port SET is_trunk=%s WHERE port_id=%s RETURNING port_id"
        return self._execute_returning(sql, (is_trunk, port_id))

    # Internal helper function: common checks for changing a VLAN on a
    # port. Raises InputError if the port does not exist, or
    # CriticalError if it is locked or in trunk mode.
    def _check_port_vlan_can_change(self, port_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise InputError("Port ID %d does not exist" % int(port_id))
        # Check "locked" first so that locked ports keep reporting the
        # same message as before; trunk ports get a message that
        # actually describes why the change was refused
        if port.is_locked:
            raise CriticalError("The port is locked")
        if port.is_trunk:
            raise CriticalError("The port is in trunk mode")

    # Set the current vlan of a port in the database. The VLAN is
    # passed by ID.
    #
    # Constraints:
    # 1. The port must already exist
    # 2. The port must not be a trunk port
    # 3. The port must not be locked
    # 4. The VLAN must already exist in the database
    #
    # Returns the ID of the updated port.
    def set_current_vlan(self, port_id, vlan_id):
        self._check_port_vlan_can_change(port_id)

        if self.get_vlan_by_id(vlan_id) is None:
            raise InputError("VLAN ID %d does not exist" % int(vlan_id))

        sql = "UPDATE port SET current_vlan_id=%s WHERE port_id=%s RETURNING port_id"
        return self._execute_returning(sql, (vlan_id, port_id))

    # Set the base vlan of a port in the database. The VLAN is
    # passed by ID.
    #
    # Constraints:
    # 1. The port must already exist
    # 2. The port must not be a trunk port
    # 3. The port must not be locked
    # 4. The VLAN must already exist in the database
    # 5. The VLAN must be flagged as a base VLAN
    #
    # Returns the ID of the updated port.
    def set_base_vlan(self, port_id, vlan_id):
        self._check_port_vlan_can_change(port_id)

        vlan = self.get_vlan_by_id(vlan_id)
        if vlan is None:
            raise InputError("VLAN ID %d does not exist" % int(vlan_id))
        if not vlan.is_base_vlan:
            raise InputError("VLAN ID %d is not a base VLAN" % int(vlan_id))

        sql = "UPDATE port SET base_vlan_id=%s WHERE port_id=%s RETURNING port_id"
        return self._execute_returning(sql, (vlan_id, port_id))

    # Internal function: Attach a port to a trunk in the database.
    # Pass TRUNK_ID_NONE as the trunk to detach the port instead.
    #
    # Constraints:
    # 1. The port must already exist
    # 2. The port must not be locked
    #
    # Returns the ID of the updated port.
    def _set_port_trunk(self, port_id, trunk_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise InputError("Port ID %d does not exist" % int(port_id))
        if port.is_locked:
            raise CriticalError("The port is locked")

        sql = "UPDATE port SET trunk_id=%s WHERE port_id=%s RETURNING port_id"
        return self._execute_returning(sql, (int(trunk_id), int(port_id)))

    # Trivial helper function to return all the rows in a given table,
    # as a list. The table name is substituted into the SQL directly,
    # so it must only ever come from our own code.
    def _dump_table(self, table):
        self.cursor.execute("SELECT * FROM %s" % table)
        return self.cursor.fetchall()

    def all_switches(self):
        return self._dump_table("switch")

    def all_ports(self):
        return self._dump_table("port")

    def all_vlans(self):
        return self._dump_table("vlan")

    def all_trunks(self):
        return self._dump_table("trunk")


# Quick self-test when run as a script: dump what the database knows.
# Each print takes a single parenthesised argument so the output is
# the same whether this is treated as a statement or a function call.
if __name__ == '__main__':
    db = VlanDB()
    s = db.all_switches()
    print('The DB knows about %d switch(es)' % len(s))
    print(s)
    p = db.all_ports()
    print('The DB knows about %d port(s)' % len(p))
    print(p)
    v = db.all_vlans()
    print('The DB knows about %d vlan(s)' % len(v))
    print(v)
    t = db.all_trunks()
    print('The DB knows about %d trunks(s)' % len(t))
    print(t)

    print('First free VLAN tag is %d' % db.find_lowest_unused_vlan_tag())