aboutsummaryrefslogtreecommitdiff
path: root/db/db.py
blob: 57d3107b438d4c39fdcaa0b9fb2fd233e7eaadcd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#! /usr/bin/python

#  Copyright 2014 Linaro Limited
#  Author: Dave Pigott <dave.pigott@linaro.org>
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software
#  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
#  MA 02110-1301, USA.

import psycopg2
import psycopg2.extras
import datetime
from errors import CriticalError

class VlanDB:
    # Data-access wrapper around a PostgreSQL database holding the
    # switch, port and vlan tables. Every mutating method commits on
    # success and rolls back (then re-raises) on failure.
    #
    # NOTE: the private helpers that take a table or column name
    # (_delete_row, _get_element, _get_row, _dump_table) splice those
    # names directly into the SQL text, because psycopg2 would quote
    # them as string literals if they were passed as parameters. Only
    # ever give them trusted, hard-coded identifiers; anything that may
    # come from a user must go through the parameterised "value".

    def __init__(self, db_name="vland", username="vland"):
        # Initialise to None first so that __del__ is safe even if the
        # connection attempt below fails part-way.
        self.connection = None
        self.cursor = None
        try:
            self.connection = psycopg2.connect(database=db_name, user=username)
            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)
        except Exception as e:
            # Failure is reported but not raised (existing behaviour);
            # subsequent calls will fail because cursor is None.
            print("Failed to access database: %s" % e)

    def __del__(self):
        # getattr() also covers instances whose __init__ never ran.
        cursor = getattr(self, "cursor", None)
        if cursor is not None:
            cursor.close()
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()

    # Create a new switch in the database and return its switch_id.
    def create_switch(self, name):
        try:
            sql = "INSERT INTO switch (name) VALUES (%s) RETURNING switch_id"
            # psycopg2 needs a *sequence* of parameters; a bare string
            # would be iterated character by character.
            data = (name,)
            self.cursor.execute(sql, data)
            switch_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise
        return switch_id

    # Create a new port in the database and return its port_id. Two of
    # the fields are created with default values (is_locked, is_trunk)
    # here, and should be updated separately if desired. For the
    # current_vlan_id and base_vlan_id fields, *BE CAREFUL* that you
    # have already looked up the correct VLAN_ID for each. This is
    # *NOT* the same as the VLAN tag (likely to be 1).
    # You Have Been Warned!
    def create_port(self, name, switch_id, current_vlan_id, base_vlan_id):
        try:
            sql = "INSERT INTO port (name, switch_id, is_locked, is_trunk, current_vlan_id, base_vlan_id) VALUES (%s, %s, %s, %s, %s, %s) RETURNING port_id"
            data = (name, switch_id,
                    False, False,
                    current_vlan_id, base_vlan_id)
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise
        return port_id

    # Create a new vlan in the database and return its vlan_id. We
    # locally add a creation timestamp (local, naive datetime) for
    # debug purposes: if vlans seem to be sticking around, we'll be
    # able to see when they were created.
    def create_vlan(self, name, tag, is_base_vlan):
        try:
            dt = datetime.datetime.now()
            sql = "INSERT INTO vlan (name, tag, is_base_vlan, creation_time) VALUES (%s, %s, %s, %s) RETURNING vlan_id"
            data = (name, tag, is_base_vlan, dt)
            self.cursor.execute(sql, data)
            vlan_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise
        return vlan_id

    # Delete the row(s) of `table` whose column `field` equals `value`.
    # `table` and `field` must be trusted identifiers (see class note).
    def _delete_row(self, table, field, value):
        try:
            # Identifiers are interpolated; the value keeps a %s
            # placeholder so psycopg2 quotes it safely.
            sql = "DELETE FROM %s WHERE %s = %s" % (table, field, "%s")
            data = (value,)
            self.cursor.execute(sql, data)
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

    def delete_switch(self, switch_id):
        self._delete_row("switch", "switch_id", switch_id)

    def delete_port(self, port_id):
        self._delete_row("port", "port_id", port_id)

    def delete_vlan(self, vlan_id):
        self._delete_row("vlan", "vlan_id", vlan_id)

    # Return the single column `select_field` from the first row of
    # `table` where `compare_field` equals `value`.
    def _get_element(self, select_field, table, compare_field, value):

        # We really want to use psycopg's type handling to deal with the
        # (potentially) user-supplied data in the value field, so we
        # have to pass (sql,data) through to cursor.execute. However,
        # we can't have psycopg do all the argument substitution here
        # as it will quote all the params like the table name. That
        # doesn't work. So, we substitute a "%s" for "%s" here so we
        # keep it after python's own string substitution.
        sql = "SELECT %s FROM %s WHERE %s = %s" % (select_field, table, compare_field, "%s")

        # psycopg2 expects a sequence of parameters, so wrap the single
        # value in a one-element tuple (note the trailing comma).
        data = (value, )
        self.cursor.execute(sql, data)

        # fetchone() returns None when nothing matches, so indexing it
        # raises here. That's OK - the caller needs to deal with that.
        return self.cursor.fetchone()[0]

    def get_switch_id(self, name):
        return self._get_element("switch_id", "switch", "name", name)

    def get_port_id(self, name):
        return self._get_element("port_id", "port", "name", name)

    def get_vlan_id_from_name(self, name):
        return self._get_element("vlan_id", "vlan", "name", name)

    def get_vlan_id_from_tag(self, tag):
        return self._get_element("vlan_id", "vlan", "tag", tag)

    def get_switch_name(self, switch_id):
        return self._get_element("name", "switch", "switch_id", switch_id)

    # The port table's name column is "name" (see create_port).
    def get_port_name(self, port_id):
        return self._get_element("name", "port", "port_id", port_id)

    # The vlan table's name column is "name" (see create_vlan).
    def get_vlan_name(self, vlan_id):
        return self._get_element("name", "vlan", "vlan_id", vlan_id)

    # Return the first full row of `table` whose `field` equals
    # `value`, or None if there is no such row. `table` and `field`
    # must be trusted identifiers (see class note).
    def _get_row(self, table, field, value):
        sql = "SELECT * FROM %s WHERE %s = %s" % (table, field, "%s")
        data = (value,)
        self.cursor.execute(sql, data)
        return self.cursor.fetchone()

    def get_switch(self, switch_id):
        return self._get_row("switch", "switch_id", switch_id)

    def get_port(self, port_id):
        return self._get_row("port", "port_id", port_id)

    def get_vlan(self, vlan_id):
        return self._get_row("vlan", "vlan_id", vlan_id)

    # (Un)Lock a port in the database. This can only be done through
    # the admin interface, and will stop API users from modifying
    # settings on the port. Use this to lock down ports that are used
    # for PDUs and other core infrastructure.
    # Returns the port_id of the updated row. If no port matches,
    # fetchone() yields None, the indexing raises, and the
    # transaction is rolled back.
    def set_port_is_locked(self, port_id, is_locked):
        try:
            # RETURNING is required: without it there is no result set
            # for fetchone() to read.
            sql = "UPDATE port SET is_locked=%s WHERE port_id=%s RETURNING port_id"
            data = (is_locked, port_id)
            self.cursor.execute(sql, data)
            port_id = self.cursor.fetchone()[0]
            self.connection.commit()
        except:
            self.connection.rollback()
            raise
        return port_id

    # Move an unlocked, non-trunk port onto the VLAN with id `vlan_id`.
    # Raises CriticalError if the port or VLAN does not exist or if
    # the port is locked/trunk.
    def set_vlan(self, port_id, vlan_id):
        port = self.get_port(port_id)
        if port is None:
            raise CriticalError("Port %s does not exist" % port_id)

        if port["is_trunk"] or port["is_locked"]:
            raise CriticalError("The port is locked")

        vlan = self.get_vlan(vlan_id)
        if vlan is None:
            raise CriticalError("VLAN %s does not exist" % vlan_id)

        try:
            sql = "UPDATE port SET current_vlan_id=%s WHERE port_id=%s"
            data = (vlan_id, port_id)
            self.cursor.execute(sql, data)
            # Commit like every other mutator; otherwise the change is
            # discarded when the connection closes.
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

    # Put an unlocked, non-trunk port back onto its base VLAN.
    # Raises CriticalError if the port does not exist or is
    # locked/trunk.
    def restore_default_vlan(self, port_id):
        port = self.get_port(port_id)
        if port is None:
            raise CriticalError("Port %s does not exist" % port_id)

        if port["is_trunk"] or port["is_locked"]:
            raise CriticalError("The port is locked")

        try:
            # psycopg2 only understands %s placeholders (not %d), and
            # the parameters must be a sequence.
            sql = "UPDATE port SET current_vlan_id=base_vlan_id WHERE port_id=%s"
            data = (port_id,)
            self.cursor.execute(sql, data)
            self.connection.commit()
        except:
            self.connection.rollback()
            raise

    # Return every row of `table` as a list. `table` must be a trusted
    # identifier (see class note).
    def _dump_table(self, table):
        self.cursor.execute("SELECT * FROM %s" % table)
        return list(self.cursor.fetchall())

    def all_switches(self):
        return self._dump_table("switch")

    def all_ports(self):
        return self._dump_table("port")

    def all_vlans(self):
        return self._dump_table("vlan")