# blob: 5961ed11e855d4954648e5b0d7c8511eab5102a2 [file] [log] [blame]
# Copyright 2014-2015 Linaro Limited
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA 02110-1301, USA.
#
# Simple VLANd IPC module
import socket
import json
import time
import datetime
import os
import sys
import logging
from Vland.errors import CriticalError, InputError, ConfigError, SocketError
class VlanIpc:
    """VLANd IPC class.

    Exchanges JSON messages over a TCP stream socket.  Every message on
    the wire is an 8-character upper-case hexadecimal length header
    followed by that many bytes of JSON text.

    One instance acts either as a server (server_init, server_listen,
    then server_recv / server_reply / server_close per request) or as a
    client (client_connect, client_send, client_recv_and_close).

    Attributes:
        socket: the bound/listening socket (server) or the connected
            socket (client); None when there is none.
        conn: server side only, the connection accepted by the most
            recent server_recv; None when there is none.
    """

    # Width of the length header: 8 hex digits, i.e. a 32-bit limit.
    _HEADER_LEN = 8
    # Upper bound for a single recv() call, so that a bogus length
    # header cannot make us ask for one enormous buffer.
    _RECV_CHUNK = 65536

    def __init__(self):
        self.conn = None
        self.socket = None

    def server_init(self, host, port):
        """Create the server socket and bind it to (host, port).

        Binding is retried once per second, forever, until it succeeds;
        each failure is reported on stdout.
        """
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.conn = None
        while True:
            try:
                self.socket.bind((host, port))
                break
            except socket.error as e:
                # Single-argument call form: prints the same text whether
                # print is a statement (Python 2) or a function.
                print("Can't bind to port %d: %s" % (port, e))
                time.sleep(1)

    def server_listen(self):
        """Start listening on the server socket (backlog of 1).

        Raises:
            SocketError: if there is no server socket.
        """
        if self.socket is None:
            raise SocketError("Server can't listen: no socket")
        self.socket.listen(1)

    def server_recv(self):
        """Accept one connection and read one JSON request from it.

        The accepted connection is kept in self.conn so that
        server_reply can answer on it.

        Returns:
            The decoded JSON object; it always contains 'client_name'.

        Raises:
            SocketError: if there is no server socket, if the message is
                truncated or cannot be decoded, or if it carries no
                'client_name'.  In the last two cases the connection is
                closed first.
        """
        if self.socket is None:
            raise SocketError("Server can't receive data: no socket")
        self.conn, addr = self.socket.accept()
        logging.debug("server: Connection from")
        logging.debug(addr)
        try:
            json_data = self._recv_message(self.conn)
        except ValueError:
            self._drop_conn()
            raise SocketError("Server unable to decode receieved data: corrupt?")
        # Only a JSON object can carry a client name; the type check
        # stops a bare number from blowing up the "in" test.
        if not isinstance(json_data, dict) or 'client_name' not in json_data:
            self._drop_conn()
            raise SocketError("Server unable to detect client name: corrupt packet?")
        return json_data

    def server_reply(self, json_data):
        """Send json_data to the client accepted by server_recv.

        A failure while sending is logged and otherwise ignored; the
        connection is left for server_close to tidy up.

        Raises:
            SocketError: if there is no connection, or if json_data
                cannot be serialised (the connection is closed first).
        """
        if self.conn is None:
            raise SocketError("Server can't send data: no connection")
        data = self._format_message(json_data)
        if not data:
            self._drop_conn()
            raise SocketError("Server unable to format reply data")
        try:
            # Length header first, then the payload.  sendall() keeps
            # going until everything is written, unlike send().
            self.conn.sendall(self._to_bytes(data[0]) + self._to_bytes(data[1]))
        except socket.error as e:
            logging.error("Can't send response to client: %s", e)
            logging.error("Was trying to send data:")
            logging.error(data)

    def server_close(self):
        """Shut down and close the current client connection, if any.

        Safe to call repeatedly: the connection is forgotten once closed.
        """
        if self.conn is not None:
            self._shutdown_and_close(self.conn)
            self.conn = None

    def client_connect(self, host, port):
        """Connect to the server at (host, port).

        If connect_ex() itself raises socket.error the attempt is
        repeated after a one second pause.

        Returns:
            True once connected.

        Raises:
            SocketError: if connect_ex() reports a non-zero error code;
                the socket is closed and discarded first.
        """
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        while True:
            try:
                ret = self.socket.connect_ex((host, port))
            except socket.error:
                time.sleep(1)
                continue
            break
        # Raised outside the try block, so the retry handler above can
        # never swallow it, whatever SocketError derives from.
        if ret:
            self.socket.close()
            self.socket = None
            raise SocketError("Client can't send connect: %s" % ret)
        return True

    def client_send(self, json_data):
        """Send json_data to the server.

        Raises:
            SocketError: if there is no socket, or if json_data cannot
                be serialised (the socket is closed and discarded first).
        """
        if self.socket is None:
            raise SocketError("Client can't send data: no socket")
        data = self._format_message(json_data)
        if not data:
            self._shutdown_and_close(self.socket)
            self.socket = None
            raise SocketError("Client unable to send data")
        # Length header first, then the payload.
        self.socket.sendall(self._to_bytes(data[0]) + self._to_bytes(data[1]))

    def client_recv_and_close(self):
        """Read one JSON reply from the server, then close the socket.

        The socket is closed and discarded whether or not decoding
        succeeds.

        Returns:
            The decoded JSON value.

        Raises:
            SocketError: if there is no socket, or if the reply is
                truncated or cannot be decoded.
        """
        if self.socket is None:
            raise SocketError("Client can't receieve data: no socket")
        try:
            json_data = self._recv_message(self.socket)
        except ValueError:
            self.socket.close()
            self.socket = None
            raise SocketError("Client unable to decode receieved data: corrupt?")
        self._shutdown_and_close(self.socket)
        self.socket = None
        return json_data

    # The default JSON serialiser code can't deal with datetime
    # objects by default, so let's tell it how to.
    def _json_serial(self, obj):
        """JSON serializer for objects not serialisable by default json code.

        datetime.datetime objects become their ISO 8601 string.  Any
        other object falls through and yields None, which json.dumps
        then writes out as null.
        """
        if isinstance(obj, datetime.datetime):
            return obj.isoformat()
        return None

    def _format_message(self, json_data):
        """Serialise json_data for the wire.

        Returns:
            A (header, body) tuple of strings, where header is the
            length of body as 8 upper-case hex digits; or None if
            json_data cannot be serialised.
        """
        try:
            msgstr = json.dumps(json_data, default=self._json_serial)
        except (TypeError, ValueError):
            # json.dumps reports unusable dictionary keys with TypeError
            # and circular references with ValueError.
            return None
        # "header" calculation.  json.dumps escapes non-ASCII text by
        # default, so this length is also the number of bytes sent.
        msglen = "%08X" % len(msgstr)
        return (msglen, msgstr)

    def _to_bytes(self, text):
        """Return text as bytes, encoding it as UTF-8 if it is not already."""
        if isinstance(text, bytes):
            return text
        return text.encode('utf-8')

    def _recv_exact(self, sock, count):
        """Read up to count bytes from sock.

        recv() may hand back less than was asked for, so keep reading
        until count bytes have arrived.  The result is shorter than
        count only if the peer closed the connection early.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(min(remaining, self._RECV_CHUNK))
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _recv_message(self, sock):
        """Read one length-prefixed JSON message from sock and decode it.

        Raises:
            ValueError: if the header is short, not hexadecimal or
                negative, if the body is truncated, or if the body is
                not valid UTF-8 encoded JSON.
        """
        header = self._recv_exact(sock, self._HEADER_LEN)
        if len(header) != self._HEADER_LEN:
            raise ValueError("short message header")
        count = int(header, 16)
        if count < 0:
            raise ValueError("negative message length")
        body = self._recv_exact(sock, count)
        if len(body) != count:
            raise ValueError("truncated message body")
        return json.loads(body.decode('utf-8'))

    def _shutdown_and_close(self, sock):
        """Shut sock down in both directions, then close it."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            # The peer may already have dropped the connection; closing
            # is still the right thing to do.
            pass
        sock.close()

    def _drop_conn(self):
        """Close the accepted client connection and forget it."""
        self.conn.close()
        self.conn = None