aboutsummaryrefslogtreecommitdiff
path: root/util.py
blob: 765203703c9b856135f132aff29a8530494b14eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
import logging
import os
import time
from db.db import VlanDB
from errors import CriticalError, InputError, ConfigError, SocketError

class VlanUtil:
    """VLANd utility functions"""

    def get_switch_driver(self, switch, config):
        logging.debug("Trying to find a driver for %s" % switch)
        driver = config.switches[switch].driver
        logging.debug("Driver: %s" % driver)
        module = __import__("drivers.%s" % driver, fromlist=[driver])
        class_ = getattr(module, driver)
        return class_(switch)

    def get_all_switches(self, config):
        for switch in sorted(config.switches):
            print "Found switch %s:" % (switch)
            print "  Probing:"

            s = util.get_switch_driver(switch, config)
            s.switch_connect(config.switches[switch].username, config.switches[switch].password)
            print "  Found details of switch:"
            s._dump_list(s._systemdata)
            print "  Switch has %d ports:" % len(s.switch_get_port_names())
            for port in s.switch_get_port_names():
                print "  %s" % port
                if 0 == 1:
                    mode = s.port_get_mode(port)
                    if mode == "trunk":
                        print "  port %s is in trunk mode, VLAN(s):" % port
                        vlans = s.port_get_trunk_vlan_list(port)
                        for vlan in vlans:
                            name = s.vlan_get_name(vlan)
                            print "    %d (%s)" % (vlan, name)
                    else:
                        vlan = s.port_get_access_vlan(port)
                        name = s.vlan_get_name(vlan)
                        print "  port %s is in access mode, VLAN %d (%s):" % (port, vlan, name)    
            s.switch_disconnect()
            del(s)

    # Simple helper wrapper for all the read-only database queries
    def perform_db_query(self, state, command, data):
        print 'perform_db_query'
        print command
        print data
        ret = {}
        db = state.db
        try:
            if command == 'db.all_switches':
                ret = db.all_switches()
            elif command == 'db.all_ports':
                ret = db.all_ports()
            elif command == 'db.all_vlans':
                ret = db.all_vlans()
            elif command == 'db.get_switch_by_id':
                ret = db.get_switch_by_id(data['switch_id'])
            elif command == 'db.get_switch_id_by_name':
                ret = db.get_switch_id_by_name(data['name'])
            elif command == 'db.get_switch_name_by_id':
                ret = db.get_switch_name_by_id(data['switch_id'])
            elif command == 'db.get_port_by_id':
                ret = db.get_port_by_id(data['port_id'])
            elif command == 'db.get_ports_by_switch':
                ret = db.get_ports_by_switch(data['switch_id'])
            elif command == 'db.get_port_by_switch_and_name':
                ret = db.get_port_by_switch_and_name(data['switch_id'], data['name'])
            elif command == 'db.get_current_vlan_id_by_port':
                ret = db.get_current_vlan_id_by_port(data['port_id'])
            elif command == 'db.get_base_vlan_id_by_port':
                ret = db.get_base_vlan_id_by_port(data['port_id'])
            elif command == 'db.get_ports_by_current_vlan':
                ret = db.get_ports_by_current_vlan(data['vlan_id'])
            elif command == 'db.get_ports_by_base_vlan':
                ret = db.get_ports_by_base_vlan(data['vlan_id'])
            elif command == 'db.get_vlan_by_id':
                ret = db.get_vlan_by_id(data['vlan_id'])
            elif command == 'db.get_vlan_id_by_name':
                ret = db.get_vlan_id_by_name(data['name'])
            elif command == 'db.get_vlan_id_by_tag':
                ret = db.get_vlan_id_by_tag(data['tag'])
            elif command == 'db.get_vlan_name_by_id':
                ret = db.get_vlan_name_by_id(data['vlan_id'])
            else:
                raise InputError("Unknown db_query command \"%s\"" % command)

        except InputError:
            raise

        except:
            raise InputError("Invalid input in query")

        return ret

    # Simple helper wrapper for all the read-only daemon state queries
    def perform_daemon_query(self, state, command, data):
        print 'perform_daemon_query'
        print command
        print data
        ret = {}
        try:
            if command == 'daemon.status':
                # data ignored
                ret['running'] = 'ok'
            elif command == 'daemon.version':
                # data ignored
                ret['version'] = state.version
            elif command == 'daemon.statistics':
                ret['uptime'] = time.time() - state.starttime
            else:
                raise InputError("Unknown daemon_query command \"%s\"" % command)

        except InputError:
            raise

        except:
            raise InputError("Invalid input in query")

        return ret

    # Helper wrapper for API functions modifying database state only
    def perform_db_update(self, state, command, data):
        print 'perform_db_update'
        print command
        print data
        ret = {}
        db = state.db
        try:
            if command == 'db.create_switch':
                ret = db.create_switch(data['name'])
            elif command == 'db.create_port':
                ret = db.create_port(data['switch_id'], data['name'],
                                     state.config.default_vlan_id,
                                     state.config.default_vlan_id)
            elif command == 'db.delete_switch':
                ret = db.delete_switch(data['switch_id'])
            elif command == 'db.delete_port':
                ret = db.delete_port(data['port_id'])
            elif command == 'db.set_port_is_locked':
                ret = db.set_port_is_locked(data['port_id'], data['is_locked'])
            elif command == 'db.set_base_vlan':
                ret = db.set_base_vlan(data['port_id'], data['base_vlan_id'])
            else:
                raise InputError("Unknown db_update command \"%s\"" % command)

        except InputError:
            raise

        except:
            raise InputError("Invalid input in query")

        return ret

    # Helper wrapper for API functions that modify both database state
    # and on-switch VLAN state
    def perform_vlan_update(self, state, command, data):
        print 'perform_vlan_update'
        print command
        print data
        ret = {}
        db = state.db
        try:
            # All of these are complex commands, so call helpers
            # rather than inline the code here
            if command == 'api.create_vlan':
                ret = self.create_vlan(state, command, data)
            elif command == 'api.delete_vlan':
                ret = self.delete_vlan(state, command, data)
            elif command == 'api.set_port_mode':
                ret = self.set_port_mode(state, command, data)
            elif command == 'api.set_current_vlan':
                ret = self.set_current_vlan(state, command, data)
            elif command == 'api.restore_base_vlan':
                ret = self.restore_base_vlan(state, command, data)
            else:
                raise InputError("Unknown query command \"%s\"" % command)

        except InputError as e:
            print 'got error %s' % e
            raise

        except:
            raise InputError("Invalid input in query")

        return ret


    # Complex call
    # 1. create the VLAN in the DB
    # 2. Iterate through all switches:
    #    a. Create the VLAN
    #    b. Add the VLAN to all trunk ports (if needed)
    # 3. If all went OK, save config on all the switches
    #
    # The VLAN may already exist on some of the switches, that's
    # fine. If things fail, we attempt to roll back by rebooting
    # switches then removing the VLAN in the DB.
    def create_vlan(self, state, command, data):

        print 'create_vlan'
        db = state.db
        config = state.config

        name = data['name']
        tag = int(data['tag'])
        is_base_vlan = data['is_base_vlan']

        # 1. Database record first
        try:
            print 'Adding DB record first: name %s, tag %d, is_base_vlan %d' % (name, tag, is_base_vlan)
            vlan_id = db.create_vlan(name, tag, is_base_vlan)
            print 'Added VLAN tag %d, name %s to the database, created VLAN ID %d' % (tag, name, vlan_id)
        except InputError:
            print 'DB creation failed'
            raise

        # Keep track of which switches we've configured, for later use
        switches_done = []

        # 2. Now the switches
        try:
            for switch in sorted(config.switches):
                trunk_ports = []
                try:
                    print 'Adding new VLAN to switch %s' % switch
                    # Get the right driver
                    s = self.get_switch_driver(switch, config)
                    s.switch_connect(config.switches[switch].username, config.switches[switch].password)

                    # Mark this switch as one we've touched, for
                    # either config saving or rollback below
                    switches_done.append(switch)

                    # 2a. Create the VLAN on the switch
                    s.vlan_create(tag)
                    s.vlan_set_name(tag, name)
                    print 'Added VLAN tag %d, name %s to switch %s' % (tag, name, switch)

                    # 2b. Do we need to worry about trunk ports on this switch?
                    if 'TrunkWildCardVlans' in s.switch_get_capabilities():
                        print 'This switch does not need special trunk port handling'
                    else:
                        print 'This switch needs special trunk port handling'
                        switch_id = db.get_switch_id_by_name(switch)
                        trunk_ports = db.get_trunk_port_names_by_switch(switch_id)
                        if trunk_ports is None:
                            print "But it has no trunk ports defined"
                            trunk_ports = []
                        else:
                            print 'Found %d trunk_ports that need adjusting' % len(trunk_ports)

                    # Modify any trunk ports as needed
                    for port in trunk_ports:
                        print 'Added VLAN tag %d, name %s to switch %s' % (tag, name, switch)
                        s.port_add_trunk_to_vlan(port, tag)

                    # And now we're done with this switch
                    s.switch_disconnect()
                    del s

                except IOError:
                    raise
        
        except IOError:
            # Bugger. Looks like one of the switch calls above
            # failed. To undo the changes safely, we'll need to reset
            # all the switches we managed to configure. This could
            # take some time!
            for switch in switches_done:
                s = self.get_switch_driver(switch, config)
                s.switch_connect(config.switches[switch].username, config.switches[switch].password)
                s.switch_restart() # Will implicitly also close the connection
                del s

            # Undo the database change
            print 'Switch access failed. Deleting the new VLAN entry in the database'
            db.delete_vlan(vlan_id)
            raise

        # If we've got this far, things were successful. Save config
        # on all the switches so it will persist across reboots
        for switch in switches_done:
            s = self.get_switch_driver(switch, config)
            s.switch_connect(config.switches[switch].username, config.switches[switch].password)
            s.switch_save_running_config()
            s.switch_disconnect()
            del s

        return vlan_id # If we're successful

    # Complex call
    # 1. Check in the DB if there are any ports on the VLAN. Bail if so
    # 2. Iterate through all switches:
    #    a. Remove the VLAN from all trunk ports (if needed)
    #    b. Remove the VLAN
    # 3. If all went OK, save config on the switches
    # 4. Remove the VLAN in the DB
    #
    # If things fail, we attempt to roll back by rebooting switches.
    def delete_vlan(self, state, command, data):

        print 'delete_vlan'
        db = state.db
        config = state.config

        vlan_id = int(data['vlan_id'])

        # 1. Check for database records first
        print 'Checking for ports using VLAN id %d' % vlan_id
        vlan = db.get_vlan_by_id(vlan_id)
        if vlan is None:
            raise InputError("VLAN ID %d does not exist" % vlan_id)
        vlan_tag = vlan.tag
        ports = db.get_ports_by_current_vlan(vlan_id)
        if ports is not None:
            raise InputError("Cannot delete VLAN ID %d when it still has %d ports" %
                             (vlan_id, len(ports)))
        ports = db.get_ports_by_base_vlan(vlan_id)
        if ports is not None:
            raise InputError("Cannot delete VLAN ID %d when it still has %d ports" %
                             (vlan_id, len(ports)))

        # Keep track of which switches we've configured, for later use
        switches_done = []

        # 2. Now the switches
        try:
            for switch in sorted(config.switches):
                trunk_ports = []
                try:
                    # Get the right driver
                    s = self.get_switch_driver(switch, config)
                    s.switch_connect(config.switches[switch].username, config.switches[switch].password)

                    # Mark this switch as one we've touched, for
                    # either config saving or rollback below
                    switches_done.append(switch)

                    # 2a. Do we need to worry about trunk ports on this switch?
                    if 'TrunkWildCardVlans' in s.switch_get_capabilities():
                        print 'This switch does not need special trunk port handling'
                    else:
                        print 'This switch needs special trunk port handling'
                        switch_id = db.get_switch_id_by_name(switch)
                        trunk_ports = db.get_trunk_port_names_by_switch(switch_id)
                        if trunk_ports is None:
                            print "But it has no trunk ports defined"
                            trunk_ports = []
                        else:
                            print 'Found %d trunk_ports that need adjusting' % len(trunk_ports)

                    # Modify any trunk ports as needed
                    for port in trunk_ports:
                        s.port_remove_trunk_from_vlan(port, tag)
                        print 'Removed VLAN tag %d from switch %s port %s' % (vlan_tag, switch, port)

                    # 2b. Remove the VLAN from the switch
                    print 'Removing VLAN tag %d from switch %s' % (vlan_tag, switch)
                    s.vlan_destroy(vlan_tag)
                    print 'Removed VLAN tag %d from switch %s' % (vlan_tag, switch)

                    # And now we're done with this switch
                    s.switch_disconnect()
                    del s

                except IOError:
                    raise

        except IOError:
            # Bugger. Looks like one of the switch calls above
            # failed. To undo the changes safely, we'll need to reset
            # all the switches we managed to configure. This could
            # take some time!
            for switch in switches_done:
                s = self.get_switch_driver(switch, config)
                s.switch_connect(config.switches[switch].username, config.switches[switch].password)
                s.switch_restart() # Will implicitly also close the connection
                del s

        # 3. If we've got this far, things were successful. Save
        # config on all the switches so it will persist across reboots
        for switch in switches_done:
            s = self.get_switch_driver(switch, config)
            s.switch_connect(config.switches[switch].username, config.switches[switch].password)
            s.switch_save_running_config()
            s.switch_disconnect()
            del s

        # 4. Finally, remove the VLAN in the DB
        try:
            print 'Removing DB record: VLAN id %d' % vlan_id
            vlan_id = db.delete_vlan(vlan_id)
            print 'Removed VLAN id %d from the database OK' % vlan_id
        except InputError:
            print 'DB deletion failed'
            raise

        return vlan_id # If we're successful