blob: f302c603db96765c4c91d02dbdefdcae89ba731d [file] [log] [blame]
#! /usr/bin/python
# Copyright 2014 Linaro Limited
# Author: Dave Pigott <dave.pigott@linaro.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA 02110-1301, USA.
import psycopg2
import psycopg2.extras
import datetime
from errors import CriticalError
class VlanDB:
    """Access layer for the vland PostgreSQL database.

    Wraps the ``switch``, ``port`` and ``vlan`` tables with create, delete,
    lookup and dump helpers. Every mutating method commits on success and,
    on any failure, rolls the transaction back and re-raises.
    """

    def __init__(self, db_name="vland", username="vland"):
        """Connect to database *db_name* as *username*.

        Rows are returned through a psycopg2 ``DictCursor``. If connecting
        fails the error is printed and the instance is left without a usable
        connection (``connection``/``cursor`` stay ``None``).
        """
        # Pre-set both attributes so __del__ is safe even when connect() or
        # cursor() raises below.
        self.connection = None
        self.cursor = None
        try:
            self.connection = psycopg2.connect(database=db_name, user=username)
            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)
        except Exception as e:
            print("Failed to access database: %s" % e)

    def __del__(self):
        """Close the cursor and connection, if they were ever opened."""
        # getattr() also covers instances created without running __init__.
        cursor = getattr(self, "cursor", None)
        if cursor is not None:
            cursor.close()
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()

    def create_switch(self, name):
        """Insert a switch called *name* and return its new switch_id."""
        sql = "INSERT INTO switch (name) VALUES (%s) RETURNING switch_id"
        try:
            # psycopg2 expects a sequence of parameters, so wrap the single
            # value in a 1-tuple rather than passing the bare string.
            self.cursor.execute(sql, (name,))
            switch_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            # Roll back on *anything*, then let the caller see the error.
            self.connection.rollback()
            raise
        return switch_id

    def create_port(self, name, switch_id, current_vlan_id, base_vlan_id):
        """Insert a port and return its new port_id.

        is_locked and is_trunk are created as False here and should be
        updated separately if desired. For current_vlan_id and base_vlan_id,
        *BE CAREFUL* that you have already looked up the correct VLAN_ID for
        each. This is *NOT* the same as the VLAN tag (likely to be 1).
        You Have Been Warned!
        """
        sql = ("INSERT INTO port (name, switch_id, is_locked, is_trunk, "
               "current_vlan_id, base_vlan_id) "
               "VALUES (%s, %s, %s, %s, %s, %s) RETURNING port_id")
        data = (name, switch_id, False, False, current_vlan_id, base_vlan_id)
        try:
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise
        return port_id

    def create_vlan(self, name, tag, is_base_vlan):
        """Insert a VLAN stamped with the current local time; return its vlan_id."""
        sql = ("INSERT INTO vlan (name, tag, is_base_vlan, creation_time) "
               "VALUES (%s, %s, %s, %s) RETURNING vlan_id")
        try:
            data = (name, tag, is_base_vlan, datetime.datetime.now())
            self.cursor.execute(sql, data)
            vlan_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise
        return vlan_id

    def _delete_row(self, table, field, value):
        """Delete rows of *table* whose *field* equals *value*, then commit.

        psycopg2 would quote identifiers passed as parameters as string
        literals, so *table* and *field* are interpolated directly into the
        SQL; they must be trusted, internal names. *value* is still passed as
        a bound parameter.
        """
        sql = "DELETE FROM %s WHERE %s = %%s" % (table, field)
        try:
            self.cursor.execute(sql, (value,))
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

    def delete_switch(self, switch_id):
        """Delete the switch with the given switch_id."""
        self._delete_row("switch", "switch_id", switch_id)

    def delete_port(self, port_id):
        """Delete the port with the given port_id."""
        self._delete_row("port", "port_id", port_id)

    def delete_vlan(self, vlan_id):
        """Delete the VLAN with the given vlan_id."""
        self._delete_row("vlan", "vlan_id", vlan_id)

    def _get_element(self, select_field, table, compare_field, value):
        """Return *select_field* of the first *table* row where *compare_field* = *value*.

        The identifiers are interpolated by Python (psycopg2 would quote them
        as string literals), while *value* goes through psycopg2's parameter
        handling because it may be user-supplied. If no row matches,
        fetchone() returns None and indexing it raises TypeError; callers are
        expected to handle that.
        """
        # "%%s" survives Python's formatting as a "%s" placeholder for psycopg2.
        sql = "SELECT %s FROM %s WHERE %s = %%s" % (select_field, table, compare_field)
        # Parameters must be a sequence, hence the 1-tuple.
        self.cursor.execute(sql, (value,))
        return self.cursor.fetchone()[0]

    def get_switch_id(self, name):
        """Look up a switch_id by switch name."""
        return self._get_element("switch_id", "switch", "name", name)

    def get_port_id(self, name):
        """Look up a port_id by port name."""
        return self._get_element("port_id", "port", "name", name)

    def get_vlan_id_from_name(self, name):
        """Look up a vlan_id by VLAN name."""
        return self._get_element("vlan_id", "vlan", "name", name)

    def get_vlan_id_from_tag(self, tag):
        """Look up a vlan_id by VLAN tag."""
        return self._get_element("vlan_id", "vlan", "tag", tag)

    def get_switch_name(self, switch_id):
        """Look up a switch name by switch_id."""
        return self._get_element("name", "switch", "switch_id", switch_id)

    def get_port_name(self, port_id):
        """Look up a port name by port_id."""
        # The port table's column is "name" (see create_port/get_port_id).
        return self._get_element("name", "port", "port_id", port_id)

    def get_vlan_name(self, vlan_id):
        """Look up a VLAN name by vlan_id."""
        # The vlan table's column is "name" (see create_vlan).
        return self._get_element("name", "vlan", "vlan_id", vlan_id)

    def _get_row(self, table, field, value):
        """Return the first row of *table* where *field* = *value*, or None.

        *table* and *field* must be trusted identifiers (interpolated
        directly); *value* is a bound parameter.
        """
        sql = "SELECT * FROM %s WHERE %s = %%s" % (table, field)
        self.cursor.execute(sql, (value,))
        return self.cursor.fetchone()

    def get_switch(self, switch_id):
        """Return the switch row for switch_id, or None."""
        return self._get_row("switch", "switch_id", switch_id)

    def get_port(self, port_id):
        """Return the port row for port_id, or None."""
        return self._get_row("port", "port_id", port_id)

    def get_vlan(self, vlan_id):
        """Return the VLAN row for vlan_id, or None."""
        return self._get_row("vlan", "vlan_id", vlan_id)

    def set_vlan(self, port_id, vlan_id):
        """Move port *port_id* onto VLAN *vlan_id* and commit.

        Raises CriticalError if the port does not exist, is a trunk or is
        locked, or if the VLAN does not exist.
        """
        port = self.get_port(port_id)
        if port is None:
            raise CriticalError("Port %s does not exist" % port_id)
        if port["is_trunk"] or port["is_locked"]:
            raise CriticalError("The port is locked")
        vlan = self.get_vlan(vlan_id)
        if vlan is None:
            raise CriticalError("VLAN %s does not exist" % vlan_id)
        sql = "UPDATE port SET current_vlan_id=%s WHERE port_id=%s"
        try:
            self.cursor.execute(sql, (vlan_id, port_id))
            # Commit like every other mutator; otherwise the update is left
            # pending in the open transaction.
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

    def restore_default_vlan(self, port_id):
        """Reset port *port_id* to its base VLAN and commit.

        Raises CriticalError if the port does not exist, is a trunk or is
        locked.
        """
        port = self.get_port(port_id)
        if port is None:
            raise CriticalError("Port %s does not exist" % port_id)
        if port["is_trunk"] or port["is_locked"]:
            raise CriticalError("The port is locked")
        # psycopg2 only understands the %s placeholder (never %d).
        sql = "UPDATE port SET current_vlan_id=base_vlan_id WHERE port_id=%s"
        try:
            self.cursor.execute(sql, (port_id,))
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

    def _dump_table(self, table):
        """Return every row of *table* (a trusted, internal table name) as a list."""
        self.cursor.execute("SELECT * FROM %s" % table)
        return list(self.cursor.fetchall())

    def all_switches(self):
        """Return all rows of the switch table."""
        return self._dump_table("switch")

    def all_ports(self):
        """Return all rows of the port table."""
        return self._dump_table("port")

    def all_vlans(self):
        """Return all rows of the vlan table."""
        return self._dump_table("vlan")