| #! /usr/bin/python |
| |
| # Copyright 2014 Linaro Limited |
| # Author: Dave Pigott <dave.pigott@linaro.org> |
| # |
| # This program is free software; you can redistribute it and/or modify |
| # it under the terms of the GNU General Public License as published by |
| # the Free Software Foundation; either version 2 of the License, or |
| # (at your option) any later version. |
| # |
| # This program is distributed in the hope that it will be useful, |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| # GNU General Public License for more details. |
| # |
| # You should have received a copy of the GNU General Public License |
| # along with this program; if not, write to the Free Software |
| # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, |
| # MA 02110-1301, USA. |
| |
| import psycopg2 |
| import psycopg2.extras |
| import datetime |
| from errors import CriticalError, InputError |
| |
class VlanDB:
    """Data-access wrapper around the vland PostgreSQL database.

    Holds one psycopg2 connection plus a NamedTupleCursor, so rows returned
    by the get_*_by_id / all_* helpers are named tuples whose fields are the
    table's column names (e.g. ``port.is_locked``). The tables used here are
    ``switch``, ``port`` and ``vlan``.

    Mutating methods commit on success and roll back (then re-raise) on any
    failure.
    """

    def __init__(self, db_name="vland", username="vland"):
        # Pre-set both handles so __del__ can tell what was actually opened
        # if the connection attempt fails part-way.
        self.connection = None
        self.cursor = None
        try:
            self.connection = psycopg2.connect(database=db_name, user=username)
            self.cursor = self.connection.cursor(
                cursor_factory=psycopg2.extras.NamedTupleCursor)
        except Exception as e:
            # Best-effort as before: report and continue. Any later query on
            # this object will fail because self.cursor is None.
            print("Failed to access database: %s" % e)

    def __del__(self):
        # __init__ may have failed (or never run, e.g. via __new__), so
        # either handle may be None or absent; only close what exists.
        cursor = getattr(self, "cursor", None)
        if cursor is not None:
            cursor.close()
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()

    # Create a new switch in the database. Switches are really simple
    # devices - they're just containers for ports.
    #
    # Constraints:
    # Switches must be uniquely named
    #
    # Returns the new switch_id; raises InputError if the name is taken.
    def create_switch(self, name):

        switch_id = self.get_switch_id_by_name(name)
        if switch_id is not None:
            raise InputError("Switch name %s already exists" % name)

        try:
            sql = "INSERT INTO switch (name) VALUES (%s) RETURNING switch_id"
            data = (name, )
            self.cursor.execute(sql, data)
            switch_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

        return switch_id

    # Create a new port in the database. Two of the fields are created
    # with default values (is_locked, is_trunk) here, and should be
    # updated separately if desired. For the current_vlan_id and
    # base_vlan_id fields, *BE CAREFUL* that you have already looked
    # up the correct VLAN_ID for each. This is *NOT* the same as the
    # VLAN tag (likely to be 1).
    # You Have Been Warned!
    #
    # Constraints:
    # 1. The switch referred to must already exist
    # 2. The VLANs mentioned here must already exist
    # 3. (Switch/name) must be unique
    #
    # Returns the new port_id; raises InputError if a constraint fails.
    def create_port(self, switch_id, name, current_vlan_id, base_vlan_id):

        switch = self.get_switch_by_id(switch_id)
        if switch is None:
            raise InputError("Switch id %s does not exist" % switch_id)

        for vlan_id in (current_vlan_id, base_vlan_id):
            vlan = self.get_vlan_by_id(vlan_id)
            if vlan is None:
                raise InputError("VLAN id %s does not exist" % vlan_id)

        port_id = self.get_port_by_switch_and_name(switch_id, name)
        if port_id is not None:
            raise InputError("Already have a port %s on switch_id %d" % (name, int(switch_id)))

        try:
            sql = "INSERT INTO port (name, switch_id, is_locked, is_trunk, current_vlan_id, base_vlan_id) VALUES (%s, %s, %s, %s, %s, %s) RETURNING port_id"
            data = (name, switch_id,
                    False, False,
                    current_vlan_id, base_vlan_id)
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

        return port_id

    # Create a new vlan in the database. We locally add a creation
    # timestamp, for debug purposes. If vlans seems to be sticking
    # around, we'll be able to see when they were created.
    #
    # Constraints:
    # Names and tags must be unique
    #
    # Returns the new vlan_id; raises InputError if name or tag is taken.
    def create_vlan(self, name, tag, is_base_vlan):

        vlan_id = self.get_vlan_id_by_name(name)
        if vlan_id is not None:
            raise InputError("VLAN name %s is already in use" % name)

        vlan_id = self.get_vlan_id_by_tag(tag)
        if vlan_id is not None:
            raise InputError("VLAN tag %d is already in use" % int(tag))

        try:
            # Naive local time, as produced by datetime.now().
            dt = datetime.datetime.now()
            sql = "INSERT INTO vlan (name, tag, is_base_vlan, creation_time) VALUES (%s, %s, %s, %s) RETURNING vlan_id"
            data = (name, tag, is_base_vlan, dt)
            self.cursor.execute(sql, data)
            vlan_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

        return vlan_id

    # Delete every row of `table` whose `field` equals `value`, committing
    # on success and rolling back on failure. `table` and `field` are
    # interpolated directly into the SQL, so they must be trusted
    # identifiers (all callers here pass literals); only `value` is bound
    # as a query parameter.
    def _delete_row(self, table, field, value):
        try:
            sql = "DELETE FROM %s WHERE %s = %s" % (table, field, '%s')
            data = (value,)
            self.cursor.execute(sql, data)
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

    # Delete the specified switch
    #
    # Constraints:
    # 1. The switch must exist
    # 2. The switch may not be referenced by any ports -
    #    delete them first!
    def delete_switch(self, switch_id):
        switch = self.get_switch_by_id(switch_id)
        if switch is None:
            raise InputError("Switch ID %s does not exist" % switch_id)
        ports = self.get_ports_by_switch(switch_id)
        if ports is not None:
            raise InputError("Cannot delete switch ID %s when it still has %d ports" %
                             (switch_id, len(ports)))
        self._delete_row("switch", "switch_id", switch_id)
        return switch_id

    # Delete the specified port
    #
    # Constraints:
    # 1. The port must exist
    # 2. The port must not be locked
    def delete_port(self, port_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise InputError("Port ID %s does not exist" % port_id)
        if port.is_locked:
            raise InputError("Cannot delete port ID %s as it is locked" % port_id)
        self._delete_row("port", "port_id", port_id)
        return port_id

    # Delete the specified VLAN
    #
    # Constraints:
    # 1. The VLAN must exist
    # 2. The VLAN may not contain any ports (as either their current or
    #    their base VLAN) - move or delete them first!
    def delete_vlan(self, vlan_id):
        vlan = self.get_vlan_by_id(vlan_id)
        if vlan is None:
            raise InputError("VLAN ID %s does not exist" % vlan_id)
        ports = self.get_ports_by_current_vlan(vlan_id)
        if ports is not None:
            raise InputError("Cannot delete VLAN ID %s when it still has %d ports" %
                             (vlan_id, len(ports)))
        ports = self.get_ports_by_base_vlan(vlan_id)
        if ports is not None:
            raise InputError("Cannot delete VLAN ID %s when it still has %d ports" %
                             (vlan_id, len(ports)))
        self._delete_row("vlan", "vlan_id", vlan_id)
        return vlan_id

    # Grab one column from one row of a query on one column; useful as a
    # quick wrapper. Returns None when no row matches.
    def _get_element(self, select_field, table, compare_field, value):

        # We want psycopg2's type handling for the (potentially)
        # user-supplied value, so (sql, data) goes through to
        # cursor.execute. psycopg2 can't substitute identifiers such as
        # the table name (it would quote them as strings), so those are
        # interpolated by Python here and a literal "%s" placeholder is
        # left behind for psycopg2 to fill in.
        sql = "SELECT %s FROM %s WHERE %s = %s" % (select_field, table, compare_field, "%s")

        # psycopg2 expects a sequence of parameters; the trailing comma
        # makes this a one-element tuple rather than a bare value.
        data = (value, )
        self.cursor.execute(sql, data)

        if self.cursor.rowcount > 0:
            return self.cursor.fetchone()[0]
        return None

    # Grab one column from one row of a query on two columns; useful as a
    # quick wrapper. Returns None when no row matches.
    def _get_element2(self, select_field, table, compare_field1, value1, compare_field2, value2):

        # Identifiers are interpolated by Python; the two values are left
        # as "%s" placeholders so psycopg2 binds them safely.
        sql = "SELECT %s FROM %s WHERE %s = %s AND %s = %s" % (select_field, table, compare_field1, "%s", compare_field2, "%s")

        # Parameters are passed to psycopg2 as a tuple, in placeholder order.
        data = (value1, value2)
        self.cursor.execute(sql, data)

        if self.cursor.rowcount > 0:
            return self.cursor.fetchone()[0]
        return None

    # Grab one column from every matching row of a query; useful as a
    # quick wrapper. Returns a list of values, or None when nothing
    # matches (callers test `is not None` to detect "any rows").
    def _get_multi_elements(self, select_field, table, compare_field, value):

        # Identifiers are interpolated by Python; the value is left as a
        # "%s" placeholder so psycopg2 binds it safely.
        sql = "SELECT %s FROM %s WHERE %s = %s" % (select_field, table, compare_field, "%s")

        # One-element parameter tuple for psycopg2.
        data = (value, )
        self.cursor.execute(sql, data)

        if self.cursor.rowcount > 0:
            return [record[0] for record in self.cursor.fetchall()]
        return None

    def get_switch_by_id(self, switch_id):
        return self._get_row("switch", "switch_id", switch_id)

    def get_switch_id_by_name(self, name):
        return self._get_element("switch_id", "switch", "name", name)

    def get_switch_name_by_id(self, switch_id):
        return self._get_element("name", "switch", "switch_id", switch_id)

    def get_port_by_id(self, port_id):
        return self._get_row("port", "port_id", port_id)

    # Port names are only unique per switch, so this returns the first
    # matching port_id when several switches share a port name.
    def get_port_id_by_name(self, name):
        return self._get_element("port_id", "port", "name", name)

    def get_port_name_by_id(self, port_id):
        # The port table's column is "name" (see create_port).
        return self._get_element("name", "port", "port_id", port_id)

    def get_ports_by_switch(self, switch_id):
        return self._get_multi_elements("port_id", "port", "switch_id", switch_id)

    def get_port_by_switch_and_name(self, switch_id, name):
        return self._get_element2("port_id", "port", "switch_id", switch_id, "name", name)

    def get_ports_by_current_vlan(self, vlan_id):
        return self._get_multi_elements("port_id", "port", "current_vlan_id", vlan_id)

    def get_ports_by_base_vlan(self, vlan_id):
        return self._get_multi_elements("port_id", "port", "base_vlan_id", vlan_id)

    def get_vlan_by_id(self, vlan_id):
        return self._get_row("vlan", "vlan_id", vlan_id)

    def get_vlan_id_by_name(self, name):
        return self._get_element("vlan_id", "vlan", "name", name)

    def get_vlan_id_by_tag(self, tag):
        return self._get_element("vlan_id", "vlan", "tag", tag)

    def get_vlan_name_by_id(self, vlan_id):
        # The vlan table's column is "name" (see create_vlan).
        return self._get_element("name", "vlan", "vlan_id", vlan_id)

    # Fetch the whole row of `table` where `field` equals `value`, as a
    # named tuple, or None when nothing matches.
    def _get_row(self, table, field, value):

        # Identifiers are interpolated by Python; the value is left as a
        # "%s" placeholder so psycopg2 binds it safely.
        sql = "SELECT * FROM %s WHERE %s = %s" % (table, field, "%s")

        # One-element parameter tuple for psycopg2.
        data = (value, )
        self.cursor.execute(sql, data)
        return self.cursor.fetchone()

    # (Un)Lock a port in the database. This can only be done through
    # the admin interface, and will stop API users from modifying
    # settings on the port. Use this to lock down ports that are used
    # for PDUs and other core infrastructure
    #
    # Returns the port_id; raises InputError if the port does not exist.
    def set_port_is_locked(self, port_id, is_locked):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise InputError("Port %s does not exist" % port_id)
        try:
            # RETURNING is required: a plain UPDATE produces no result set,
            # and fetchone() would then raise.
            sql = "UPDATE port SET is_locked=%s WHERE port_id=%s RETURNING port_id"
            data = (is_locked, port_id)
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise
        return port_id

    # Move a port into a different VLAN (by vlan_id, not tag).
    #
    # Raises CriticalError if the port or VLAN does not exist, or if the
    # port is a trunk or locked. Returns the port_id.
    def set_vlan(self, port_id, vlan_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise CriticalError("Port %s does not exist" % port_id)

        # Rows come from a NamedTupleCursor, so use attribute access.
        if port.is_trunk or port.is_locked:
            raise CriticalError("The port is locked")

        vlan = self.get_vlan_by_id(vlan_id)
        if vlan is None:
            raise CriticalError("VLAN %s does not exist" % vlan_id)

        try:
            sql = "UPDATE port SET current_vlan_id=%s WHERE port_id=%s"
            data = (vlan_id, port_id)
            self.cursor.execute(sql, data)
            self.connection.commit()
        except:
            self.connection.rollback()
            raise
        return port_id

    # Put a port back into its base VLAN.
    #
    # Raises CriticalError if the port does not exist or is a trunk or
    # locked. Returns the port_id.
    def restore_default_vlan(self, port_id):
        port = self.get_port_by_id(port_id)
        if port is None:
            raise CriticalError("Port %s does not exist" % port_id)

        # Rows come from a NamedTupleCursor, so use attribute access.
        if port.is_trunk or port.is_locked:
            raise CriticalError("The port is locked")

        try:
            # psycopg2 only supports %s placeholders, with a parameter sequence.
            sql = "UPDATE port SET current_vlan_id=base_vlan_id WHERE port_id=%s"
            data = (port_id, )
            self.cursor.execute(sql, data)
            self.connection.commit()
        except:
            self.connection.rollback()
            raise
        return port_id

    # Return every row of `table` as a list of named tuples. `table` is
    # interpolated directly, so it must be a trusted identifier.
    def _dump_table(self, table):
        self.cursor.execute("SELECT * FROM %s" % table)
        return list(self.cursor.fetchall())

    def all_switches(self):
        return self._dump_table("switch")

    def all_ports(self):
        return self._dump_table("port")

    def all_vlans(self):
        return self._dump_table("vlan")
| |