| #! /usr/bin/python |
| |
| # Copyright 2014 Linaro Limited |
| # Author: Dave Pigott <dave.pigott@linaro.org> |
| # |
| # This program is free software; you can redistribute it and/or modify |
| # it under the terms of the GNU General Public License as published by |
| # the Free Software Foundation; either version 2 of the License, or |
| # (at your option) any later version. |
| # |
| # This program is distributed in the hope that it will be useful, |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| # GNU General Public License for more details. |
| # |
| # You should have received a copy of the GNU General Public License |
| # along with this program; if not, write to the Free Software |
| # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, |
| # MA 02110-1301, USA. |
| |
| import psycopg2 |
| import psycopg2.extras |
| import datetime |
| from errors import CriticalError |
| |
class VlanDB:
    """Access layer for the vland PostgreSQL database.

    Holds one psycopg2 connection plus a DictCursor, so fetched rows can
    be indexed by column name (e.g. ``port["is_locked"]``). The methods
    below read and write the ``switch``, ``port`` and ``vlan`` tables.

    NOTE: ``switch_id``, ``port_id`` and ``vlan_id`` are database row ids.
    A VLAN's row id is *not* the same thing as its ``tag``; use
    ``get_vlan_id_from_tag()`` to translate between them.
    """

    def __init__(self, db_name="vland", username="vland"):
        """Connect to database *db_name* as *username*.

        On failure a diagnostic is printed and the psycopg2 exception is
        re-raised.
        """
        try:
            self.connection = psycopg2.connect(database=db_name, user=username)
            self.cursor = self.connection.cursor(
                cursor_factory=psycopg2.extras.DictCursor)
        except Exception as e:
            print("Failed to access database: %s" % e)
            # Swallowing the error used to leave an object without a
            # cursor/connection that only failed later (and in __del__)
            # with a confusing AttributeError. Fail fast instead.
            raise

    def __del__(self):
        """Close the cursor and connection, if they were ever created."""
        # __init__ may have failed part-way, so either attribute may be
        # missing; never raise from a finaliser.
        cursor = getattr(self, "cursor", None)
        if cursor is not None:
            cursor.close()
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()

    def _write(self, sql, data, returning=False):
        """Execute one modifying statement and commit it.

        :param sql: SQL text using ``%s`` placeholders.
        :param data: sequence of values bound to the placeholders.
        :param returning: if true, return the first column of the first
            result row (e.g. the id produced by a ``RETURNING`` clause).
        :returns: that value, or None when *returning* is false.
        :raises: whatever the cursor raises. The transaction is rolled
            back first.
        """
        try:
            self.cursor.execute(sql, data)
            result = self.cursor.fetchone()[0] if returning else None
            self.connection.commit()
        except BaseException:
            # Roll back on *anything* (matches the original bare except)
            # so the connection is not left in an aborted transaction.
            self.connection.rollback()
            raise
        return result

    def create_switch(self, name):
        """Insert a switch called *name* and return its new switch_id."""
        # psycopg2 requires the parameters as a sequence. A bare string
        # would be treated as one parameter per character.
        return self._write(
            "INSERT INTO switch (name) VALUES (%s) RETURNING switch_id",
            (name,), returning=True)

    def create_port(self, name, switch_id, current_vlan_id, base_vlan_id):
        """Insert a port and return its new port_id.

        is_locked and is_trunk are created as False and should be updated
        separately if desired. For current_vlan_id and base_vlan_id,
        *BE CAREFUL* that you have already looked up the correct VLAN_ID
        (database row id) for each. This is *NOT* the same as the VLAN tag
        (likely to be 1). You Have Been Warned!
        """
        sql = ("INSERT INTO port (name, switch_id, is_locked, is_trunk, "
               "current_vlan_id, base_vlan_id) "
               "VALUES (%s, %s, %s, %s, %s, %s) RETURNING port_id")
        data = (name, switch_id,
                False, False,
                current_vlan_id, base_vlan_id)
        return self._write(sql, data, returning=True)

    def create_vlan(self, name, tag, is_base_vlan):
        """Insert a VLAN, stamped with the current local time; return its vlan_id."""
        sql = ("INSERT INTO vlan (name, tag, is_base_vlan, creation_time) "
               "VALUES (%s, %s, %s, %s) RETURNING vlan_id")
        data = (name, tag, is_base_vlan, datetime.datetime.now())
        return self._write(sql, data, returning=True)

    def _delete_row(self, table, field, value):
        """Delete the rows of *table* whose *field* equals *value*, then commit.

        psycopg2 can only bind values, not identifiers. *table* and *field*
        are therefore interpolated directly and must be trusted (every
        caller in this class hard-codes them). Only *value* is bound. The
        doubled ``%%s`` survives Python formatting as a ``%s`` placeholder.
        """
        sql = "DELETE FROM %s WHERE %s = %%s" % (table, field)
        self._write(sql, (value,))

    def delete_switch(self, switch_id):
        """Delete the switch with the given switch_id."""
        self._delete_row("switch", "switch_id", switch_id)

    def delete_port(self, port_id):
        """Delete the port with the given port_id."""
        self._delete_row("port", "port_id", port_id)

    def delete_vlan(self, vlan_id):
        """Delete the VLAN with the given vlan_id."""
        self._delete_row("vlan", "vlan_id", vlan_id)

    def _get_element(self, select_field, table, compare_field, value):
        """Return *select_field* from the first row of *table* where
        *compare_field* equals *value*.

        The identifiers (*select_field*, *table*, *compare_field*) are
        interpolated by Python because psycopg2 would quote them as string
        literals. They must be trusted. *value* (potentially user-supplied)
        is bound by psycopg2 via the ``%s`` that ``%%s`` leaves behind.
        It is passed as a one-element tuple because psycopg2 requires a
        sequence.

        If no row matches, ``fetchone()`` returns None and indexing it
        raises TypeError. The caller needs to deal with that.
        """
        sql = "SELECT %s FROM %s WHERE %s = %%s" % (select_field, table,
                                                    compare_field)
        self.cursor.execute(sql, (value,))
        return self.cursor.fetchone()[0]

    def get_switch_id(self, name):
        """Return the switch_id of the switch called *name*."""
        return self._get_element("switch_id", "switch", "name", name)

    def get_port_id(self, name):
        """Return the port_id of the port called *name*."""
        return self._get_element("port_id", "port", "name", name)

    def get_vlan_id_from_name(self, name):
        """Return the vlan_id of the VLAN called *name*."""
        return self._get_element("vlan_id", "vlan", "name", name)

    def get_vlan_id_from_tag(self, tag):
        """Return the vlan_id of the VLAN with the given *tag*."""
        return self._get_element("vlan_id", "vlan", "tag", tag)

    def get_switch_name(self, switch_id):
        """Return the name of the switch with the given switch_id."""
        return self._get_element("name", "switch", "switch_id", switch_id)

    def get_port_name(self, port_id):
        """Return the name of the port with the given port_id."""
        # The column is "name" (see create_port); "port_name" does not exist.
        return self._get_element("name", "port", "port_id", port_id)

    def get_vlan_name(self, vlan_id):
        """Return the name of the VLAN with the given vlan_id."""
        # The column is "name" (see create_vlan); "vlan_name" does not exist.
        return self._get_element("name", "vlan", "vlan_id", vlan_id)

    def _get_row(self, table, field, value):
        """Return the first row of *table* whose *field* equals *value*.

        Returns None when nothing matches. As in _get_element, only *value*
        is bound by psycopg2. *table* and *field* must be trusted identifiers.
        """
        sql = "SELECT * FROM %s WHERE %s = %%s" % (table, field)
        self.cursor.execute(sql, (value,))
        return self.cursor.fetchone()

    def get_switch(self, switch_id):
        """Return the switch row for *switch_id*, or None."""
        return self._get_row("switch", "switch_id", switch_id)

    def get_port(self, port_id):
        """Return the port row for *port_id*, or None."""
        return self._get_row("port", "port_id", port_id)

    def get_vlan(self, vlan_id):
        """Return the VLAN row for *vlan_id*, or None."""
        return self._get_row("vlan", "vlan_id", vlan_id)

    def set_vlan(self, port_id, vlan_id):
        """Set the current VLAN of port *port_id* to *vlan_id* and commit.

        :raises CriticalError: if the port does not exist, is a trunk or is
            locked, or if the VLAN does not exist.
        """
        port = self.get_port(port_id)
        if port is None:
            # Was raise("..."): raising a plain string is itself an error.
            raise CriticalError("Port %s does not exist" % port_id)

        if port["is_trunk"] or port["is_locked"]:
            raise CriticalError("The port is locked")

        vlan = self.get_vlan(vlan_id)
        if vlan is None:
            raise CriticalError("VLAN %s does not exist" % vlan_id)

        # Committed like every other write in this class (it previously
        # was not, so the change was never persisted).
        self._write("UPDATE port SET current_vlan_id=%s WHERE port_id=%s",
                    (vlan_id, port_id))

    def restore_default_vlan(self, port_id):
        """Reset port *port_id*'s current VLAN to its base VLAN and commit.

        :raises CriticalError: if the port does not exist, or is a trunk or
            locked.
        """
        port = self.get_port(port_id)
        if port is None:
            raise CriticalError("Port %s does not exist" % port_id)

        if port["is_trunk"] or port["is_locked"]:
            raise CriticalError("The port is locked")

        # psycopg2 only supports %s placeholders (never %d), and the
        # parameters must be a sequence.
        self._write(
            "UPDATE port SET current_vlan_id=base_vlan_id WHERE port_id=%s",
            (port_id,))

    def _dump_table(self, table):
        """Return every row of *table* as a list (*table* must be trusted)."""
        self.cursor.execute("SELECT * FROM %s" % table)
        return list(self.cursor.fetchall())

    def all_switches(self):
        """Return all rows of the switch table."""
        return self._dump_table("switch")

    def all_ports(self):
        """Return all rows of the port table."""
        return self._dump_table("port")

    def all_vlans(self):
        """Return all rows of the vlan table."""
        return self._dump_table("vlan")
| |