Add trunks, simple containers to track inter-switch connections

To help with visualisation, add trunks: simple containers that describe
inter-switch connections in the database.

This entails:

 * a new 'trunk' table in the database, containing nothing but the
   trunk_id and creation_time fields.
 * a new trunk_id field in the port table, so that each port either
   belongs to exactly one trunk or has trunk_id -1 (TRUNK_ID_NONE,
   i.e. not on any trunk)

Creating a trunk then entails:

 * Creating a new entry in the trunk table
 * Attaching both ports to that trunk_id

Also add matching helper functions and admin functions.

Also plumb through the --lookup_ports_by_switch,
--lookup_ports_by_current_vlan and --lookup_ports_by_base_vlan options.

Change-Id: I97f7aa9a14eecbfab9a57f5e776ad21c5944b369
diff --git a/db/db.py b/db/db.py
index aeb65a1..99ea454 100644
--- a/db/db.py
+++ b/db/db.py
@@ -1,7 +1,8 @@
 #! /usr/bin/python
 
-#  Copyright 2014 Linaro Limited
-#  Author: Dave Pigott <dave.pigott@linaro.org>
+#  Copyright 2014-2015 Linaro Limited
+#  Authors: Dave Pigott <dave.pigott@linaro.org>,
+#           Steve McIntyre <steve.mcintyre@linaro.org>
 #
 #  This program is free software; you can redistribute it and/or modify
 #  it under the terms of the GNU General Public License as published by
@@ -23,6 +24,9 @@
 import datetime, os, sys
 import logging
 
+# Value stored in port.trunk_id for ports that are not on any trunk
+TRUNK_ID_NONE = -1
+
 if __name__ == '__main__':
     vlandpath = os.path.abspath(os.path.normpath(os.path.dirname(sys.argv[0])))
     sys.path.insert(0, vlandpath)
@@ -66,13 +69,13 @@
 
         return switch_id
 
-    # Create a new port in the database. Two of the fields are created
-    # with default values (is_locked, is_trunk) here, and should be
-    # updated separately if desired. For the current_vlan_id and
-    # base_vlan_id fields, *BE CAREFUL* that you have already looked
-    # up the correct VLAN_ID for each. This is *NOT* the same as the
-    # VLAN tag (likely to be 1).
-    # You Have Been Warned!
+    # Create a new port in the database. Three of the fields are
+    # created with default values (is_locked, is_trunk, trunk_id)
+    # here, and should be updated separately if desired. For the
+    # current_vlan_id and base_vlan_id fields, *BE CAREFUL* that you
+    # have already looked up the correct VLAN_ID for each. This is
+    # *NOT* the same as the VLAN tag (likely to be 1). You Have Been
+    # Warned!
     #
     # Constraints:
     # 1. The switch referred to must already exist
@@ -99,10 +102,10 @@
             raise InputError("Already have a port %d on switch ID %d" % (int(number), int(switch_id)))
 
         try:
-            sql = "INSERT INTO port (name, number, switch_id, is_locked, is_trunk, current_vlan_id, base_vlan_id) VALUES (%s, %s, %s, %s, %s, %s, %s) RETURNING port_id"
+            sql = "INSERT INTO port (name, number, switch_id, is_locked, is_trunk, current_vlan_id, base_vlan_id, trunk_id) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) RETURNING port_id"
             data = (name, number, switch_id,
                     False, False,
-                    current_vlan_id, base_vlan_id)
+                    current_vlan_id, base_vlan_id, TRUNK_ID_NONE)
             self.cursor.execute(sql, data)
             port_id = self.cursor.fetchone()[0]
             self.connection.commit()
@@ -149,6 +152,63 @@
 
         return vlan_id
 
+    # Create a new trunk in the database, linking two ports. Trunks
+    # are really simple objects for our use - they're just containers
+    # for 2 ports.
+    #
+    # Constraints:
+    # 1. Both ports listed must already exist.
+    # 2. Both ports must be in trunk mode.
+    # 3. Both must not be locked.
+    # 4. Both must not already be in a trunk.
+    # 5. The two ports must be different.
+    #
+    # Raises InputError if any constraint is not met. Returns the new
+    # trunk_id. If attaching either port fails after the trunk row has
+    # been committed, the half-built trunk is deleted again (detaching
+    # any port already moved onto it) and the original error re-raised.
+    def create_trunk(self, port_id1, port_id2):
+
+        for port_id in (port_id1, port_id2):
+            port = self.get_port_by_id(int(port_id))
+            if port is None:
+                raise InputError("Port ID %d does not exist" % int(port_id))
+            if not port.is_trunk:
+                raise InputError("Port ID %d is not in trunk mode" % int(port_id))
+            if port.is_locked:
+                raise InputError("Port ID %d is locked" % int(port_id))
+            if port.trunk_id != TRUNK_ID_NONE:
+                raise InputError("Port ID %d is already on trunk ID %d" % (int(port_id), int(port.trunk_id)))
+
+        # A trunk from a port to itself would silently end up holding a
+        # single port, so reject it up front.
+        if int(port_id1) == int(port_id2):
+            raise InputError("Cannot create a trunk from port ID %d to itself" % int(port_id1))
+
+        # Add the trunk itself. If this fails, no trunk_id exists yet, so
+        # there is nothing to clean up beyond the failed transaction.
+        try:
+            dt = datetime.datetime.now()
+            sql = "INSERT INTO trunk (creation_time) VALUES (%s) RETURNING trunk_id"
+            data = (dt, )
+            self.cursor.execute(sql, data)
+            trunk_id = self.cursor.fetchone()[0]
+            self.connection.commit()
+        except:
+            self.connection.rollback()
+            raise
+
+        # And update the ports, removing the half-built trunk again if
+        # either update fails so no orphaned trunk row is left behind.
+        try:
+            for port_id in (port_id1, port_id2):
+                self._set_port_trunk(port_id, trunk_id)
+        except:
+            self.delete_trunk(trunk_id)
+            raise
+
+        return trunk_id
+
     # Internal helper function
     def _delete_row(self, table, field, value):
         try:
@@ -211,6 +253,32 @@
         self._delete_row("vlan", "vlan_id", vlan_id)
         return vlan_id
 
+    # Delete the specified trunk
+    #
+    # Constraints:
+    # 1. The trunk must exist (else InputError)
+    # 2. No port attached to it may be locked (else CriticalError)
+    #
+    # Any ports attached will be detached (i.e. moved to trunk
+    # TRUNK_ID_NONE) before the trunk row itself is removed. All ports
+    # are checked before any of them is changed, so a locked port
+    # cannot leave the trunk half-detached. Returns trunk_id.
+    def delete_trunk(self, trunk_id):
+        trunk = self.get_trunk_by_id(trunk_id)
+        if trunk is None:
+            raise InputError("Trunk ID %d does not exist" % int(trunk_id))
+        # get_ports_by_trunk() is documented to return None on failure;
+        # treat that the same as "no ports attached".
+        ports = self.get_ports_by_trunk(trunk_id) or []
+        for port_id in ports:
+            port = self.get_port_by_id(port_id)
+            if port is not None and port.is_locked:
+                raise CriticalError("The port is locked")
+        for port_id in ports:
+            self._set_port_trunk(port_id, TRUNK_ID_NONE)
+        self._delete_row("trunk", "trunk_id", trunk_id)
+        return trunk_id
+
     # Find the lowest unused VLAN tag and return it
     #
     # Constraints:
@@ -419,6 +477,13 @@
     def get_ports_by_base_vlan(self, vlan_id):
         return self._get_multi_elements("port_id", "port", "base_vlan_id", int(vlan_id))
 
+    # Simple lookup: look up a trunk by ID, and return the IDs of all
+    # the ports attached to it (normally the two ends of that trunk;
+    # TRUNK_ID_NONE gives the ports that are on no trunk at all).
+    # Returns None on failure.
+    def get_ports_by_trunk(self, trunk_id):
+        return self._get_multi_elements("port_id", "port", "trunk_id", int(trunk_id))
+
     # Simple lookup: look up a VLAN by ID, and return all the details
     # of that VLAN.
     #
@@ -454,6 +519,13 @@
     def get_vlan_tag_by_id(self, vlan_id):
         return self._get_element("tag", "vlan", "vlan_id", int(vlan_id))
 
+    # Simple lookup: look up a trunk by ID, and return all the details
+    # of that trunk (its row from the trunk table).
+    #
+    # Returns None on failure.
+    def get_trunk_by_id(self, trunk_id):
+        return self._get_row("trunk", "trunk_id", int(trunk_id))
+
     # Grab one row of a query on one column; useful as a quick wrapper
     def _get_row(self, table, field, value):
 
@@ -580,6 +652,28 @@
             raise
         return port_id
 
+    # Internal function: attach a port to a trunk in the database, or
+    # detach it by passing TRUNK_ID_NONE. trunk_id is not checked
+    # against the trunk table here. Returns the port ID. Constraints:
+    # 1. The port must already exist (else InputError)
+    # 2. The port must not be locked (else CriticalError)
+    def _set_port_trunk(self, port_id, trunk_id):
+        port = self.get_port_by_id(port_id)
+        if port is None:
+            raise InputError("Port ID %d does not exist" % int(port_id))
+        if port.is_locked:
+            raise CriticalError("The port is locked")
+        try:
+            sql = "UPDATE port SET trunk_id=%s WHERE port_id=%s RETURNING port_id"
+            data = (int(trunk_id), int(port_id))
+            self.cursor.execute(sql, data)
+            port_id = self.cursor.fetchone()[0]
+            self.connection.commit()
+        except:
+            self.connection.rollback()  # abandon the failed transaction, then re-raise
+            raise
+        return port_id
+
     # Trivial helper function to return all the rows in a given table
     def _dump_table(self, table):
         result = []
@@ -599,6 +693,10 @@
     def all_vlans(self):
         return self._dump_table("vlan")
 
+    # Return every row of the trunk table.
+    def all_trunks(self):
+        return self._dump_table("trunk")
+
 if __name__ == '__main__':
     db = VlanDB()
     s = db.all_switches()
@@ -610,5 +707,8 @@
     v = db.all_vlans()
     print 'The DB knows about %d vlan(s)' % len(v)
     print v
+    t = db.all_trunks()
+    print 'The DB knows about %d trunk(s)' % len(t)
+    print t
 
     print 'First free VLAN tag is %d' % db.find_lowest_unused_vlan_tag()