blob: 71e458d1ad1138d348b4f3973493fa596e003d15 [file] [log] [blame]
#! /usr/bin/python
# Copyright 2014-2015 Linaro Limited
# Authors: Dave Pigott <dave.pigott@linaro.org>,
# Steve McIntyre <steve.mcintyre@linaro.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA 02110-1301, USA.
import psycopg2
import psycopg2.extras
import datetime, os, sys
import logging
# Value stored in a port's trunk_id field while the port is not
# attached to any trunk.
TRUNK_ID_NONE = -1

# When this file is run directly as a script, put both its own
# directory and the parent directory at the front of the module search
# path, so that the local imports which follow can be resolved.
if __name__ == '__main__':
    script_dir = os.path.dirname(sys.argv[0])
    vlandpath = os.path.abspath(os.path.normpath(script_dir))
    # Inserted in this order so the parent directory ends up first
    for extra_path in (vlandpath, "%s/.." % vlandpath):
        sys.path.insert(0, extra_path)
from errors import CriticalError, InputError
# Database access layer for VLANd. An instance holds one PostgreSQL
# connection and one cursor, and offers create / delete / lookup /
# update helpers for the switch, port, vlan and trunk tables.
#
# Lookup helpers return None when nothing matches. Constraint
# violations are reported by raising InputError or CriticalError.
class VlanDB:
    def __init__(self, db_name="vland", username="vland"):
        # Set both attributes up front so that __del__ stays safe even
        # if the connection attempt below fails part-way through
        self.connection = None
        self.cursor = None
        try:
            self.connection = psycopg2.connect(database=db_name, user=username)
            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
        except Exception as e:
            logging.error("Failed to access database: %s", e)
            raise

    def __del__(self):
        # Use getattr: the instance may never have finished __init__
        cursor = getattr(self, "cursor", None)
        if cursor is not None:
            cursor.close()
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()

    # Internal helper: run a data-modifying statement that ends in
    # "RETURNING <id column>", commit it and return that ID. On any
    # failure the transaction is rolled back and the error re-raised.
    def _execute_returning_id(self, sql, data):
        try:
            self.cursor.execute(sql, data)
            row_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            # Deliberately bare: roll back on *any* error before
            # propagating it unchanged to the caller
            self.connection.rollback()
            raise
        return row_id

    # Internal helper: set one field of one port and return the port ID
    def _update_port(self, field, value, port_id):
        sql = "UPDATE port SET %s=%%s WHERE port_id=%%s RETURNING port_id" % field
        return self._execute_returning_id(sql, (value, port_id))

    # Internal helper: look up a switch by ID, raising InputError if
    # it does not exist
    def _require_switch(self, switch_id):
        switch = self.get_switch_by_id(switch_id)
        if switch is None:
            raise InputError("Switch ID %d does not exist" % int(switch_id))
        return switch

    # Internal helper: look up a port by ID, raising InputError if it
    # does not exist
    def _require_port(self, port_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise InputError("Port ID %d does not exist" % int(port_id))
        return port

    # Internal helper: look up a VLAN by ID, raising InputError if it
    # does not exist
    def _require_vlan(self, vlan_id):
        vlan = self.get_vlan_by_id(vlan_id)
        if vlan is None:
            raise InputError("VLAN ID %d does not exist" % int(vlan_id))
        return vlan

    # Internal helper: refuse to change the VLAN settings of a port
    # that is locked or in trunk mode
    def _require_vlan_changeable(self, port):
        if port.is_locked:
            raise CriticalError("The port is locked")
        if port.is_trunk:
            raise CriticalError("The port is in trunk mode")

    # Create a new switch in the database. Switches are really simple
    # devices - they're just containers for ports.
    #
    # Constraints:
    # Switches must be uniquely named
    #
    # Returns the ID of the new switch.
    def create_switch(self, name):
        if self.get_switch_id_by_name(name) is not None:
            raise InputError("Switch name %s already exists" % name)

        sql = "INSERT INTO switch (name) VALUES (%s) RETURNING switch_id"
        return self._execute_returning_id(sql, (name, ))

    # Create a new port in the database. Three of the fields are
    # created with default values (is_locked, is_trunk, trunk_id)
    # here, and should be updated separately if desired. For the
    # current_vlan_id and base_vlan_id fields, *BE CAREFUL* that you
    # have already looked up the correct VLAN_ID for each. This is
    # *NOT* the same as the VLAN tag (likely to be 1). You Have Been
    # Warned!
    #
    # Constraints:
    # 1. The switch referred to must already exist
    # 2. The VLANs mentioned here must already exist
    # 3. (Switch/name) must be unique
    # 4. (Switch/number) must be unique
    #
    # Returns the ID of the new port.
    def create_port(self, switch_id, name, number, current_vlan_id, base_vlan_id):
        self._require_switch(switch_id)
        for vlan_id in (current_vlan_id, base_vlan_id):
            self._require_vlan(vlan_id)

        if self.get_port_by_switch_and_name(switch_id, name) is not None:
            raise InputError("Already have a port %s on switch ID %d" % (name, int(switch_id)))
        if self.get_port_by_switch_and_number(switch_id, int(number)) is not None:
            raise InputError("Already have a port %d on switch ID %d" % (int(number), int(switch_id)))

        sql = "INSERT INTO port (name, number, switch_id, is_locked, is_trunk, current_vlan_id, base_vlan_id, trunk_id) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) RETURNING port_id"
        data = (name, number, switch_id,
                False, False,
                current_vlan_id, base_vlan_id, TRUNK_ID_NONE)
        return self._execute_returning_id(sql, data)

    # Create a new vlan in the database. We locally add a creation
    # timestamp, for debug purposes. If vlans seem to be sticking
    # around, we'll be able to see when they were created.
    #
    # Constraints:
    # Names and tags must be unique
    # Tags must be in the range 1-4095 (802.1q spec)
    # Names can be any free-form text, length 1-32 characters
    #
    # Returns the ID of the new VLAN.
    def create_vlan(self, name, tag, is_base_vlan):
        if int(tag) < 1 or int(tag) > 4095:
            raise InputError("VLAN tag %d is outside of the valid range (1-4095)" % int(tag))
        if (len(name) < 1) or (len(name) > 32):
            raise InputError("VLAN name %s is invalid (must be 1-32 chars)" % name)

        if self.get_vlan_id_by_name(name) is not None:
            raise InputError("VLAN name %s is already in use" % name)
        if self.get_vlan_id_by_tag(tag) is not None:
            raise InputError("VLAN tag %d is already in use" % int(tag))

        sql = "INSERT INTO vlan (name, tag, is_base_vlan, creation_time) VALUES (%s, %s, %s, %s) RETURNING vlan_id"
        data = (name, tag, is_base_vlan, datetime.datetime.now())
        return self._execute_returning_id(sql, data)

    # Create a new trunk in the database, linking two ports. Trunks
    # are really simple objects for our use - they're just containers
    # for 2 ports.
    #
    # Constraints:
    # 1. Both ports listed must already exist.
    # 2. Both ports must be in trunk mode.
    # 3. Both must not be locked.
    # 4. Both must not already be in a trunk.
    #
    # Returns the ID of the new trunk.
    def create_trunk(self, port_id1, port_id2):
        for port_id in (port_id1, port_id2):
            port = self._require_port(port_id)
            if not port.is_trunk:
                raise InputError("Port ID %d is not in trunk mode" % int(port_id))
            if port.is_locked:
                raise InputError("Port ID %d is locked" % int(port_id))
            if port.trunk_id != TRUNK_ID_NONE:
                raise InputError("Port ID %d is already on trunk ID %d" % (int(port_id), int(port.trunk_id)))

        # Add the trunk itself. If this fails there is nothing to
        # clean up beyond the rollback the helper already performs.
        sql = "INSERT INTO trunk (creation_time) VALUES (%s) RETURNING trunk_id"
        trunk_id = self._execute_returning_id(sql, (datetime.datetime.now(), ))

        # And update the ports. The trunk row is already committed at
        # this point, so on failure remove it again (which also
        # detaches any port that was attached) before re-raising.
        try:
            for port_id in (port_id1, port_id2):
                self._set_port_trunk(port_id, trunk_id)
        except:
            self.delete_trunk(trunk_id)
            raise

        return trunk_id

    # Internal helper function: delete every row of "table" whose
    # "field" equals "value". The table and field names are spliced
    # into the SQL directly, so they must never come from user input;
    # the value is passed through psycopg2's parameter handling.
    def _delete_row(self, table, field, value):
        try:
            sql = "DELETE FROM %s WHERE %s = %s" % (table, field, '%s')
            data = (value,)
            self.cursor.execute(sql, data)
            self.connection.commit()
        except:
            # Deliberately bare: roll back on *any* error, then re-raise
            self.connection.rollback()
            raise

    # Delete the specified switch
    #
    # Constraints:
    # 1. The switch must exist
    # 2. The switch may not be referenced by any ports -
    #    delete them first!
    #
    # Returns the ID of the deleted switch.
    def delete_switch(self, switch_id):
        self._require_switch(switch_id)
        ports = self.get_ports_by_switch(switch_id)
        if ports is not None:
            raise InputError("Cannot delete switch ID %d when it still has %d ports" %
                             (int(switch_id), len(ports)))
        self._delete_row("switch", "switch_id", switch_id)
        return switch_id

    # Delete the specified port
    #
    # Constraints:
    # 1. The port must exist
    # 2. The port must not be locked
    #
    # Returns the ID of the deleted port.
    def delete_port(self, port_id):
        port = self._require_port(port_id)
        if port.is_locked:
            raise InputError("Cannot delete port ID %d as it is locked" % int(port_id))
        self._delete_row("port", "port_id", port_id)
        return port_id

    # Delete the specified VLAN
    #
    # Constraints:
    # 1. The VLAN must exist
    # 2. The VLAN may not contain any ports - move or delete them first!
    #
    # Returns the ID of the deleted VLAN.
    def delete_vlan(self, vlan_id):
        self._require_vlan(vlan_id)
        # Check both the current-VLAN and the base-VLAN references
        for ports in (self.get_ports_by_current_vlan(vlan_id),
                      self.get_ports_by_base_vlan(vlan_id)):
            if ports is not None:
                raise InputError("Cannot delete VLAN ID %d when it still has %d ports" %
                                 (int(vlan_id), len(ports)))
        self._delete_row("vlan", "vlan_id", vlan_id)
        return vlan_id

    # Delete the specified trunk
    #
    # Constraints:
    # 1. The trunk must exist
    #
    # Any ports attached will be detached (i.e. moved to trunk TRUNK_ID_NONE)
    #
    # Returns the ID of the deleted trunk.
    def delete_trunk(self, trunk_id):
        trunk = self.get_trunk_by_id(trunk_id)
        if trunk is None:
            raise InputError("Trunk ID %d does not exist" % int(trunk_id))
        # get_ports_by_trunk returns None (not an empty list) when no
        # ports are attached, so guard the iteration
        for port_id in self.get_ports_by_trunk(trunk_id) or ():
            self._set_port_trunk(port_id, TRUNK_ID_NONE)
        self._delete_row("trunk", "trunk_id", trunk_id)
        return trunk_id

    # Find the lowest unused VLAN tag and return it. The search starts
    # from tag 1, so the lowest value ever returned is 2.
    #
    # Constraints:
    # None
    def find_lowest_unused_vlan_tag(self):
        sql = "SELECT tag FROM vlan ORDER BY tag ASC"
        self.cursor.execute(sql)

        # Walk through the sorted list, stopping at the first gap; if
        # there is no gap the answer is one past the highest tag seen
        last = 1
        for record in self.cursor:
            if (record[0] - last) > 1:
                break
            last = record[0]
        result = last + 1

        # NOTE(review): the limit here is 4093 although create_vlan
        # accepts tags up to 4095 - confirm whether that is intended
        if result > 4093:
            raise CriticalError("Can't find any VLAN tags remaining for allocation!")
        return result

    # Internal helper: after a SELECT has been executed, return the
    # first column of the first row, or None if no rows matched
    def _first_column_or_none(self):
        if self.cursor.rowcount > 0:
            return self.cursor.fetchone()[0]
        return None

    # Internal helper: after a SELECT has been executed, return the
    # first column of every row as a list, or None if no rows matched
    def _column_list_or_none(self):
        if self.cursor.rowcount > 0:
            return [record[0] for record in self.cursor]
        return None

    # A note that applies to all of the _get_* helpers below:
    #
    # We really want to use psycopg's type handling to deal with the
    # (potentially) user-supplied data in the value fields, so we
    # have to pass (sql, data) through to cursor.execute. However, we
    # can't have psycopg do all the argument substitution as it would
    # quote the table and column names too, and that doesn't work. So
    # python's own string substitution fills in the names and
    # replaces each value slot with a literal "%s", which psycopg2
    # then fills in from the data tuple. The data must be a sequence,
    # hence the trailing comma on the one-element tuples.

    # Grab one column from one row of a query on one column; useful as
    # a quick wrapper
    def _get_element(self, select_field, table, compare_field, value):
        sql = "SELECT %s FROM %s WHERE %s = %s" % (select_field, table, compare_field, "%s")
        self.cursor.execute(sql, (value, ))
        return self._first_column_or_none()

    # Grab one column from one row of a query on 2 columns; useful as
    # a quick wrapper
    def _get_element2(self, select_field, table, compare_field1, value1, compare_field2, value2):
        sql = "SELECT %s FROM %s WHERE %s = %s AND %s = %s" % (select_field, table, compare_field1, "%s", compare_field2, "%s")
        self.cursor.execute(sql, (value1, value2))
        return self._first_column_or_none()

    # Grab one column from multiple rows of a query; useful as a quick
    # wrapper
    def _get_multi_elements(self, select_field, table, compare_field, value, sort_field):
        sql = "SELECT %s FROM %s WHERE %s = %s ORDER BY %s ASC" % (select_field, table, compare_field, "%s", sort_field)
        self.cursor.execute(sql, (value, ))
        return self._column_list_or_none()

    # Grab one column from multiple rows of a 2-part query; useful as
    # a wrapper
    def _get_multi_elements2(self, select_field, table, compare_field1, value1, compare_field2, value2, sort_field):
        sql = "SELECT %s FROM %s WHERE %s = %s AND %s = %s ORDER by %s ASC" % (select_field, table, compare_field1, "%s", compare_field2, "%s", sort_field)
        self.cursor.execute(sql, (value1, value2))
        return self._column_list_or_none()

    # Grab one whole row of a query on one column; useful as a quick
    # wrapper. Returns None if no row matched.
    def _get_row(self, table, field, value):
        sql = "SELECT * FROM %s WHERE %s = %s" % (table, field, "%s")
        self.cursor.execute(sql, (value, ))
        return self.cursor.fetchone()

    # Simple lookup: look up a switch by ID, and return all the
    # details of that switch.
    #
    # Returns None on failure.
    def get_switch_by_id(self, switch_id):
        return self._get_row("switch", "switch_id", int(switch_id))

    # Simple lookup: look up a switch by name, and return the ID of
    # that switch.
    #
    # Returns None on failure.
    def get_switch_id_by_name(self, name):
        return self._get_element("switch_id", "switch", "name", name)

    # Simple lookup: look up a switch by ID, and return the name of
    # that switch.
    #
    # Returns None on failure.
    def get_switch_name_by_id(self, switch_id):
        return self._get_element("name", "switch", "switch_id", int(switch_id))

    # Simple lookup: look up a port by ID, and return all the details
    # of that port.
    #
    # Returns None on failure.
    def get_port_by_id(self, port_id):
        return self._get_row("port", "port_id", int(port_id))

    # Simple lookup: look up a switch by ID, and return the IDs of all
    # the ports on that switch.
    #
    # Returns None on failure.
    def get_ports_by_switch(self, switch_id):
        return self._get_multi_elements("port_id", "port", "switch_id", int(switch_id), "port_id")

    # More complex lookup: look up the names of all the trunk ports on
    # a switch by ID
    #
    # Returns None on failure.
    def get_trunk_port_names_by_switch(self, switch_id):
        return self._get_multi_elements2("name", "port", "switch_id", int(switch_id), "is_trunk", True, "port_id")

    # Simple lookup: look up a port by its name and its parent switch
    # by ID, and return the ID of the port.
    #
    # Returns None on failure.
    def get_port_by_switch_and_name(self, switch_id, name):
        return self._get_element2("port_id", "port", "switch_id", int(switch_id), "name", name)

    # Simple lookup: look up a port by its number and its parent
    # switch by ID, and return the ID of the port.
    #
    # Returns None on failure.
    def get_port_by_switch_and_number(self, switch_id, number):
        return self._get_element2("port_id", "port", "switch_id", int(switch_id), "number", int(number))

    # Simple lookup: look up a port by ID, and return the current VLAN
    # id of that port.
    #
    # Returns None on failure.
    def get_current_vlan_id_by_port(self, port_id):
        return self._get_element("current_vlan_id", "port", "port_id", int(port_id))

    # Simple lookup: look up a port by ID, and return the base VLAN
    # id of that port.
    #
    # Returns None on failure.
    def get_base_vlan_id_by_port(self, port_id):
        return self._get_element("base_vlan_id", "port", "port_id", int(port_id))

    # Simple lookup: look up a current VLAN by ID, and return the IDs
    # of all the ports on that VLAN.
    #
    # Returns None on failure.
    def get_ports_by_current_vlan(self, vlan_id):
        return self._get_multi_elements("port_id", "port", "current_vlan_id", int(vlan_id), "port_id")

    # Simple lookup: look up a base VLAN by ID, and return the IDs
    # of all the ports on that VLAN.
    #
    # Returns None on failure.
    def get_ports_by_base_vlan(self, vlan_id):
        return self._get_multi_elements("port_id", "port", "base_vlan_id", int(vlan_id), "port_id")

    # Simple lookup: look up a trunk by ID, and return the IDs of the
    # ports on both ends of that trunk.
    #
    # Returns None on failure.
    def get_ports_by_trunk(self, trunk_id):
        return self._get_multi_elements("port_id", "port", "trunk_id", int(trunk_id), "port_id")

    # Simple lookup: look up a VLAN by ID, and return all the details
    # of that VLAN.
    #
    # Returns None on failure.
    def get_vlan_by_id(self, vlan_id):
        return self._get_row("vlan", "vlan_id", int(vlan_id))

    # Simple lookup: look up a VLAN by name, and return the ID of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_id_by_name(self, name):
        return self._get_element("vlan_id", "vlan", "name", name)

    # Simple lookup: look up a VLAN by tag, and return the ID of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_id_by_tag(self, tag):
        return self._get_element("vlan_id", "vlan", "tag", int(tag))

    # Simple lookup: look up a VLAN by ID, and return the name of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_name_by_id(self, vlan_id):
        return self._get_element("name", "vlan", "vlan_id", int(vlan_id))

    # Simple lookup: look up a VLAN by ID, and return the tag of that
    # VLAN.
    #
    # Returns None on failure.
    def get_vlan_tag_by_id(self, vlan_id):
        return self._get_element("tag", "vlan", "vlan_id", int(vlan_id))

    # Simple lookup: look up a trunk by ID, and return all the details
    # of that trunk.
    #
    # Returns None on failure.
    def get_trunk_by_id(self, trunk_id):
        return self._get_row("trunk", "trunk_id", int(trunk_id))

    # (Un)Lock a port in the database. This can only be done through
    # the admin interface, and will stop API users from modifying
    # settings on the port. Use this to lock down ports that are used
    # for PDUs and other core infrastructure
    #
    # Returns the ID of the updated port.
    def set_port_is_locked(self, port_id, is_locked):
        self._require_port(port_id)
        return self._update_port("is_locked", is_locked, port_id)

    # Set the mode of a port in the database. Valid values for mode
    # are "trunk" and "access"
    #
    # Returns the ID of the updated port.
    def set_port_mode(self, port_id, mode):
        self._require_port(port_id)
        if mode == "access":
            is_trunk = False
        elif mode == "trunk":
            is_trunk = True
        else:
            raise InputError("Port mode %s is not valid" % mode)
        return self._update_port("is_trunk", is_trunk, port_id)

    # Set the current vlan of a port in the database. The VLAN is
    # passed by ID.
    #
    # Constraints:
    # 1. The port must already exist
    # 2. The port must not be a trunk port
    # 3. The port must not be locked
    # 4. The VLAN must already exist in the database
    #
    # Returns the ID of the updated port.
    def set_current_vlan(self, port_id, vlan_id):
        port = self._require_port(port_id)
        self._require_vlan_changeable(port)
        self._require_vlan(vlan_id)
        return self._update_port("current_vlan_id", vlan_id, port_id)

    # Set the base vlan of a port in the database. The VLAN is
    # passed by ID.
    #
    # Constraints:
    # 1. The port must already exist
    # 2. The port must not be a trunk port
    # 3. The port must not be locked
    # 4. The VLAN must already exist in the database
    # 5. The VLAN must be flagged as a base VLAN
    #
    # Returns the ID of the updated port.
    def set_base_vlan(self, port_id, vlan_id):
        port = self._require_port(port_id)
        self._require_vlan_changeable(port)
        vlan = self._require_vlan(vlan_id)
        if not vlan.is_base_vlan:
            raise InputError("VLAN ID %d is not a base VLAN" % int(vlan_id))
        return self._update_port("base_vlan_id", vlan_id, port_id)

    # Internal function: Attach a port to a trunk in the database.
    #
    # Constraints:
    # 1. The port must already exist
    # 2. The port must not be locked
    #
    # Returns the ID of the updated port.
    def _set_port_trunk(self, port_id, trunk_id):
        port = self._require_port(port_id)
        if port.is_locked:
            raise CriticalError("The port is locked")
        return self._update_port("trunk_id", int(trunk_id), int(port_id))

    # Trivial helper function to return all the rows in a given table,
    # sorted by the "order" column. Both names are spliced into the
    # SQL directly, so they must never come from user input.
    def _dump_table(self, table, order):
        self.cursor.execute("SELECT * FROM %s ORDER by %s ASC" % (table, order))
        return list(self.cursor.fetchall())

    # Return every row of the switch table, ordered by switch ID
    def all_switches(self):
        return self._dump_table("switch", "switch_id")

    # Return every row of the port table, ordered by port ID
    def all_ports(self):
        return self._dump_table("port", "port_id")

    # Return every row of the vlan table, ordered by VLAN ID
    def all_vlans(self):
        return self._dump_table("vlan", "vlan_id")

    # Return every row of the trunk table, ordered by trunk ID
    def all_trunks(self):
        return self._dump_table("trunk", "trunk_id")
# Simple smoke test when run directly as a script: connect to the
# database with the default settings, then print the contents of each
# table and the first free VLAN tag.
#
# Each print takes a single parenthesised argument, which produces the
# same output under Python 2 and is also valid syntax under Python 3.
if __name__ == '__main__':
    db = VlanDB()
    s = db.all_switches()
    print('The DB knows about %d switch(es)' % len(s))
    print(s)
    p = db.all_ports()
    print('The DB knows about %d port(s)' % len(p))
    print(p)
    v = db.all_vlans()
    print('The DB knows about %d vlan(s)' % len(v))
    print(v)
    t = db.all_trunks()
    print('The DB knows about %d trunks(s)' % len(t))
    print(t)
    print('First free VLAN tag is %d' % db.find_lowest_unused_vlan_tag())