Commit e69aa084 by Szeberényi Imre

fix: 2<->3, win-virtio, new uppdate scheme

parent ba6ba173
...@@ -8,6 +8,7 @@ import win32event ...@@ -8,6 +8,7 @@ import win32event
import win32service import win32service
import win32serviceutil import win32serviceutil
#import agent
from agent import main as agent_main, reactor from agent import main as agent_main, reactor
logger = logging.getLogger() logger = logging.getLogger()
...@@ -17,7 +18,7 @@ formatter = logging.Formatter( ...@@ -17,7 +18,7 @@ formatter = logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] %(message)s") "%(asctime)s - %(name)s [%(levelname)s] %(message)s")
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
level = os.environ.get('LOGLEVEL', 'INFO') level = os.environ.get('LOGLEVEL', 'DEBUG')
logger.setLevel(level) logger.setLevel(level)
logger.info("%s loaded", __file__) logger.info("%s loaded", __file__)
...@@ -46,7 +47,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework): ...@@ -46,7 +47,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def main(): def main():
if len(sys.argv) == 1: if len(sys.argv) == 0: # never happen set 1 for debugging!!!!
# service must be starting... # service must be starting...
# for the sake of debugging etc, we use win32traceutil to see # for the sake of debugging etc, we use win32traceutil to see
# any unhandled exceptions and print statements. # any unhandled exceptions and print statements.
......
...@@ -6,24 +6,13 @@ import platform ...@@ -6,24 +6,13 @@ import platform
import subprocess import subprocess
import sys import sys
system = platform.system() # noqa
if system == "Linux" or system == "FreeBSD": # noqa
try: # noqa
chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa
except Exception: # noqa
pass # hope it works # noqa
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
import uptime import uptime
import logging import logging
from inspect import getargspec, isfunction from inspect import getargs, isfunction
from utils import SerialLineReceiverBase from utils import SerialLineReceiverBase
...@@ -31,9 +20,37 @@ from utils import SerialLineReceiverBase ...@@ -31,9 +20,37 @@ from utils import SerialLineReceiverBase
# (relative import error. # (relative import error.
from context import BaseContext, get_context, get_serial # noqa from context import BaseContext, get_context, get_serial # noqa
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
try:
from inspect import getfullargspec as getargspec
except ImportError:
from inspect import getargspec as getargspec
def foo(a, *args, **kwargs):
pass
system = platform.system() # noqa
if system == "Linux" or system == "FreeBSD": # noqa
try: # noqa
chdir(sys.path[0]) # noqa
subprocess.call(('pip', 'install', '-r', 'requirements.txt')) # noqa
except Exception: # noqa
pass # hope it works # noqa
Context = get_context() Context = get_context()
logging.basicConfig() logging.basicConfig(
format="[%(asctime)s] %(levelname)s [agent %(process)d/%(thread)d] %(module)s.%(funcName)s:%(lineno)d] %(message)s",
datefmt="%d/%b/%Y %H:%M:%S",
)
logger = logging.getLogger() logger = logging.getLogger()
level = environ.get('LOGLEVEL', 'INFO') level = environ.get('LOGLEVEL', 'INFO')
...@@ -45,35 +62,42 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -45,35 +62,42 @@ class SerialLineReceiver(SerialLineReceiverBase):
def __init__(self): def __init__(self):
super(SerialLineReceiver, self).__init__() super(SerialLineReceiver, self).__init__()
self.tickId = LoopingCall(self.tick) self.tickId = LoopingCall(self.tick)
self.mayStartNowId = LoopingCall(self.mayStartNow)
reactor.addSystemEventTrigger("before", "shutdown", self.shutdown) reactor.addSystemEventTrigger("before", "shutdown", self.shutdown)
self.running = True self.running = True
def connectionMade(self): def connectionMade(self):
logger.debug("connectionMade") logger.debug("connectionMade")
self.clearLineBuffer()
self.tickId.start(5, now=False) self.tickId.start(5, now=False)
self.transport.dataBuffer = b"" self.mayStartNowId.start(10, now=False)
self.transport._tempDataBuffer = [] # will be added to dataBuffer in doWrite self.send_startMsg()
self.transport._tempDataLen = 0
self.transport.write('\r\n')
if self.running:
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(), 'system': system})
def connectionLost(self, reason): def connectionLost(self, reason):
logger.debug("connectionLost") logger.debug("connectionLost")
if self.tickId.running:
self.tickId.stop() self.tickId.stop()
if self.mayStartNowId.running:
self.mayStartNowId.stop()
def connectionLost2(self, reason): def connectionLost2(self, reason):
self.send_command(command='agent_stopped', self.send_command(command='agent_stopped',
args={}) args={})
def mayStartNow(self):
if BaseContext.placed:
self.mayStartNowId.stop()
logger.info("Placed")
return
self.send_startMsg()
def tick(self): def tick(self):
logger.debug("Sending tick") logger.debug("Sending tick")
try: try:
self.send_status() self.send_status()
except Exception: except Exception as e:
logger.exception("Twisted hide exception") logger.debug("Exception durig tick: %s" % e)
# logger.exception("Twisted hide exception")
def shutdown(self): def shutdown(self):
self.running = False self.running = False
...@@ -83,6 +107,18 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -83,6 +107,18 @@ class SerialLineReceiver(SerialLineReceiverBase):
reactor.callLater(0.3, d.callback, "1") reactor.callLater(0.3, d.callback, "1")
return d return d
def send_startMsg(self):
logger.debug("Sending start message...")
# Hack for flushing the lower level buffersr
self.transport.dataBuffer = b""
self.transport._tempDataBuffer = [] # will be added to dataBuffer in doWrite
self.transport._tempDataLen = 0
self.transport.write('\r\n')
if self.running:
self.send_command(
command='agent_started',
args={'version': Context.get_agent_version(), 'system': system})
def send_status(self): def send_status(self):
import psutil import psutil
disk_usage = dict((disk.device.replace('/', '_'), disk_usage = dict((disk.device.replace('/', '_'),
...@@ -98,6 +134,7 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -98,6 +134,7 @@ class SerialLineReceiver(SerialLineReceiverBase):
logger.debug("send_status finished") logger.debug("send_status finished")
def _check_args(self, func, args): def _check_args(self, func, args):
logger.debug("_check_args %s %s" % (func, args))
if not isinstance(args, dict): if not isinstance(args, dict):
raise TypeError("Arguments should be all keyword-arguments in a " raise TypeError("Arguments should be all keyword-arguments in a "
"dict for command %s instead of %s." % "dict for command %s instead of %s." %
...@@ -105,7 +142,11 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -105,7 +142,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
# check for unexpected keyword arguments # check for unexpected keyword arguments
argspec = getargspec(func) argspec = getargspec(func)
if argspec.keywords is None: # _operation doesn't take ** args try:
_kwargs = argspec.keywords
except AttributeError:
_kwargs = argspec.varkw
if _kwargs is None: # _operation doesn't take ** args
unexpected_kwargs = set(args) - set(argspec.args) unexpected_kwargs = set(args) - set(argspec.args)
if unexpected_kwargs: if unexpected_kwargs:
raise TypeError( raise TypeError(
...@@ -119,10 +160,11 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -119,10 +160,11 @@ class SerialLineReceiver(SerialLineReceiverBase):
if missing_kwargs: if missing_kwargs:
raise TypeError("Command %s missing arguments: %s" % ( raise TypeError("Command %s missing arguments: %s" % (
self._pretty_fun(func), ", ".join(missing_kwargs))) self._pretty_fun(func), ", ".join(missing_kwargs)))
logger.debug("_check_args finished")
def _get_command(self, command, args): def _get_command(self, command, args):
logger.debug("_get_command %s %s" % (command, args)) logger.debug("_get_command %s %s" % (command, args))
if not isinstance(command, basestring) or command.startswith('_'): if not isinstance(command, unicode) or command.startswith('_'):
raise AttributeError(u'Invalid command: %s' % command) raise AttributeError(u'Invalid command: %s' % command)
try: try:
func = getattr(Context, command) func = getattr(Context, command)
...@@ -153,7 +195,9 @@ class SerialLineReceiver(SerialLineReceiverBase): ...@@ -153,7 +195,9 @@ class SerialLineReceiver(SerialLineReceiverBase):
def handle_command(self, command, args): def handle_command(self, command, args):
logger.debug("handle_command %s %s" % (command, args)) logger.debug("handle_command %s %s" % (command, args))
func = self._get_command(command, args) func = self._get_command(command, args)
logger.debug("Call cmd: %s %s" % (func, args))
retval = func(**args) retval = func(**args)
logger.debug("Retval: %s" % retval)
self.send_response( self.send_response(
response=func.__name__, response=func.__name__,
args={'retval': retval, 'uuid': args.get('uuid', None)}) args={'retval': retval, 'uuid': args.get('uuid', None)})
......
""" This is the defautl context file. It replaces the Context class """ This is the defautl context file. It replaces the Context class
to the platform specific one. to the platform specific one.
""" """
import logging
import platform import platform
logger = logging.getLogger()
def _get_virtio_device(): def _get_virtio_device():
path = None path = None
...@@ -18,6 +20,8 @@ def _get_virtio_device(): ...@@ -18,6 +20,8 @@ def _get_virtio_device():
i.children[0].instance_id.lower().replace('\\', '#') + i.children[0].instance_id.lower().replace('\\', '#') +
"#" + GUID.lower() "#" + GUID.lower()
) )
break
logger.debug("DEV found: %s", path)
return path return path
...@@ -36,6 +40,7 @@ def get_context(): ...@@ -36,6 +40,7 @@ def get_context():
def get_serial(): def get_serial():
system = platform.system() system = platform.system()
logger.debug("Get_serial system: %s", system)
port = None port = None
if system == 'Windows': if system == 'Windows':
port = _get_virtio_device() port = _get_virtio_device()
...@@ -70,6 +75,8 @@ def get_serial(): ...@@ -70,6 +75,8 @@ def get_serial():
class BaseContext(object): class BaseContext(object):
placed = False # if we reciwed password or net commands
@staticmethod @staticmethod
def change_password(password): def change_password(password):
pass pass
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
# Notify user about vm expiring # Notify user about vm expiring
## ##
import cookielib
import errno import errno
import json import json
import logging import logging
...@@ -13,8 +12,27 @@ import multiprocessing ...@@ -13,8 +12,27 @@ import multiprocessing
import os import os
import platform import platform
import subprocess import subprocess
import urllib2
from urlparse import urlsplit try:
import cookielib
except ImportError:
import http.cookiejar as cookielib
try:
import urllib2
except ImportError:
import urllib.request as urllib2
try:
from urlparse import urlsplit
except ImportError:
from urllib.parse import urlsplit
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger() logger = logging.getLogger()
logger.debug("notify imported") logger.debug("notify imported")
...@@ -58,7 +76,7 @@ def accept(): ...@@ -58,7 +76,7 @@ def accept():
from pytz import UTC from pytz import UTC
file_path = os.path.join(get_temp_dir(), file_name) file_path = os.path.join(get_temp_dir(), file_name)
if not os.path.isfile(file_path): if not os.path.isfile(file_path):
print "There is no recent notification to accept." print("There is no recent notification to accept.")
return False return False
# Load the saved url # Load the saved url
...@@ -83,10 +101,10 @@ def accept(): ...@@ -83,10 +101,10 @@ def accept():
new_local_time = parsed_time.astimezone( new_local_time = parsed_time.astimezone(
get_localzone()).strftime("%Y-%m-%d %H:%M:%S") get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
except ValueError as e: except ValueError as e:
print "Parsing time failed: %s" % e print("Parsing time failed: %s" % e)
except Exception as e: except Exception as e:
print e print(e)
print "Renewal failed. Please try it manually at %s" % url print("Renewal failed. Please try it manually at %s" % url)
logger.exception("renew failed") logger.exception("renew failed")
return False return False
else: else:
...@@ -144,6 +162,7 @@ def open_in_browser(url): ...@@ -144,6 +162,7 @@ def open_in_browser(url):
def mount_smb(url): def mount_smb(url):
data = urlsplit(url) data = urlsplit(url)
share = data.path.lstrip('/') share = data.path.lstrip('/')
print("host: %s share %s user: %s pw: %s" % (data.hostname, share, data.username, data.password))
subprocess.call(('net', 'use', 'Z:', '/delete')) subprocess.call(('net', 'use', 'Z:', '/delete'))
try: try:
p = subprocess.Popen(( p = subprocess.Popen((
...@@ -229,8 +248,19 @@ if win: ...@@ -229,8 +248,19 @@ if win:
class SubProtocol(basic.LineReceiver): class SubProtocol(basic.LineReceiver):
def connectionMade(self):
logger.info("Subclient connected: %s", unicode(self))
clients.add(self)
def connectionLost(self, reason):
logger.info("Subclient disconnected: %s", unicode(self))
clients.remove(self)
def lineReceived(self, line): def lineReceived(self, line):
print "received", line logger.debug("received %s %s" % (line, type(line)))
if not isinstance(line, str):
line = line.decode()
if line.startswith('cifs://'): if line.startswith('cifs://'):
mount_smb(line) mount_smb(line)
else: else:
...@@ -243,7 +273,7 @@ if win: ...@@ -243,7 +273,7 @@ if win:
def run_client(): def run_client():
from twisted.internet import reactor from twisted.internet import reactor
print "connect to localhost:%d" % port print("connect to localhost:%d" % port)
reactor.connectTCP("localhost", port, SubFactory()) reactor.connectTCP("localhost", port, SubFactory())
reactor.run() reactor.run()
......
...@@ -3,6 +3,11 @@ import json ...@@ -3,6 +3,11 @@ import json
import logging import logging
import platform import platform
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger() logger = logging.getLogger()
system = platform.system() system = platform.system()
...@@ -13,12 +18,13 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -13,12 +18,13 @@ class SerialLineReceiverBase(LineReceiver, object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if system == "FreeBSD": if system == "FreeBSD":
self.delimiter = '\n' self.delimiter = b'\n'
else: else:
self.delimiter = '\r' self.delimiter = b'\r'
super(SerialLineReceiverBase, self).__init__(*args, **kwargs) super(SerialLineReceiverBase, self).__init__(*args, **kwargs)
def send_response(self, response, args): def send_response(self, response, args):
# logger.debug("send_response %s %s" % (response, args))
self.transport.write(json.dumps({'response': response, self.transport.write(json.dumps({'response': response,
'args': args}) + '\r\n') 'args': args}) + '\r\n')
...@@ -33,6 +39,11 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -33,6 +39,11 @@ class SerialLineReceiverBase(LineReceiver, object):
raise NotImplementedError("Subclass must implement abstract method") raise NotImplementedError("Subclass must implement abstract method")
def lineReceived(self, data): def lineReceived(self, data):
logger.debug("lineReceived: %s", data)
if (isinstance(data, unicode)):
data = data.strip('\0')
else:
data = data.strip(b'\0')
try: try:
data = json.loads(data) data = json.loads(data)
args = data.get('args', {}) args = data.get('args', {})
...@@ -43,6 +54,7 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -43,6 +54,7 @@ class SerialLineReceiverBase(LineReceiver, object):
logger.debug('[serial] valid json: %s' % (data, )) logger.debug('[serial] valid json: %s' % (data, ))
except (ValueError, KeyError) as e: except (ValueError, KeyError) as e:
logger.error('[serial] invalid json: %s (%s)' % (data, e)) logger.error('[serial] invalid json: %s (%s)' % (data, e))
self.clearLineBuffer()
return return
if command is not None and isinstance(command, unicode): if command is not None and isinstance(command, unicode):
...@@ -50,7 +62,8 @@ class SerialLineReceiverBase(LineReceiver, object): ...@@ -50,7 +62,8 @@ class SerialLineReceiverBase(LineReceiver, object):
try: try:
self.handle_command(command, args) self.handle_command(command, args)
except Exception as e: except Exception as e:
logger.exception(u'Unhandled exception: ') logger.exception("Unhandled exception during line recived: ")
elif response is not None and isinstance(response, unicode): elif response is not None and isinstance(response, unicode):
logger.debug('received reply: %s (%s)' % (response, args)) logger.debug('received reply: %s (%s)' % (response, args))
self.clearLineBuffer()
self.handle_response(response, args) self.handle_response(response, args)
...@@ -65,7 +65,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework): ...@@ -65,7 +65,7 @@ class AppServerSvc (win32serviceutil.ServiceFramework):
def main(): def main():
if len(sys.argv) == 1: if len(sys.argv) == 0: # never happen set 1 for debugging!!!!
# service must be starting... # service must be starting...
# for the sake of debugging etc, we use win32traceutil to see # for the sake of debugging etc, we use win32traceutil to see
# any unhandled exceptions and print statements. # any unhandled exceptions and print statements.
......
...@@ -7,8 +7,8 @@ from os.path import join ...@@ -7,8 +7,8 @@ from os.path import join
import logging import logging
import tarfile import tarfile
from StringIO import StringIO from io import StringIO
from base64 import decodestring from base64 import b64decode
from hashlib import md5 from hashlib import md5
from datetime import datetime from datetime import datetime
import win32api import win32api
...@@ -20,6 +20,11 @@ from twisted.internet import reactor ...@@ -20,6 +20,11 @@ from twisted.internet import reactor
from .network import change_ip_windows from .network import change_ip_windows
from context import BaseContext from context import BaseContext
try:
# Python 2: "unicode" is built-in
unicode
except NameError:
unicode = str
logger = logging.getLogger() logger = logging.getLogger()
...@@ -28,6 +33,7 @@ class Context(BaseContext): ...@@ -28,6 +33,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def change_password(password): def change_password(password):
BaseContext.placed = True
from win32com import adsi from win32com import adsi
ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud') ads_obj = adsi.ADsGetObject('WinNT://localhost/%s,user' % 'cloud')
ads_obj.Getinfo() ads_obj.Getinfo()
...@@ -39,6 +45,7 @@ class Context(BaseContext): ...@@ -39,6 +45,7 @@ class Context(BaseContext):
@staticmethod @staticmethod
def change_ip(interfaces, dns): def change_ip(interfaces, dns):
BaseContext.placed = True
nameservers = dns.replace(' ', '').split(',') nameservers = dns.replace(' ', '').split(',')
change_ip_windows(interfaces, nameservers) change_ip_windows(interfaces, nameservers)
...@@ -111,7 +118,7 @@ class Context(BaseContext): ...@@ -111,7 +118,7 @@ class Context(BaseContext):
local_checksum = md5(data).hexdigest() local_checksum = md5(data).hexdigest()
if local_checksum != checksum: if local_checksum != checksum:
raise Exception("Checksum missmatch the file is damaged.") raise Exception("Checksum missmatch the file is damaged.")
decoded = StringIO(decodestring(data)) decoded = StringIO(b64decode(data))
try: try:
tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz') tar = tarfile.TarFile.open("dummy", fileobj=decoded, mode='r|gz')
tar.extractall(working_directory) tar.extractall(working_directory)
......
...@@ -7,12 +7,14 @@ Serial port support for Windows. ...@@ -7,12 +7,14 @@ Serial port support for Windows.
Requires PySerial and pywin32. Requires PySerial and pywin32.
""" """
# system imports
import win32file import win32file
import win32event import win32event
import win32con import win32con
# system imports
from serial.serialutil import to_bytes # type: ignore[import]
from time import sleep
# twisted imports # twisted imports
from twisted.internet import abstract from twisted.internet import abstract
...@@ -27,9 +29,13 @@ class SerialPort(abstract.FileDescriptor): ...@@ -27,9 +29,13 @@ class SerialPort(abstract.FileDescriptor):
connected = 1 connected = 1
def __init__(self, protocol, deviceNameOrPortNumber, reactor): def __init__(self, protocol, deviceName, reactor):
self.initHard(protocol, deviceName, reactor)
self.initSoft(protocol, deviceName, reactor)
def initHard(self, protocol, deviceName, reactor):
self.hComPort = win32file.CreateFile( self.hComPort = win32file.CreateFile(
deviceNameOrPortNumber, deviceName,
win32con.GENERIC_READ | win32con.GENERIC_WRITE, win32con.GENERIC_READ | win32con.GENERIC_WRITE,
0, # exclusive access 0, # exclusive access
None, # no security None, # no security
...@@ -38,101 +44,103 @@ class SerialPort(abstract.FileDescriptor): ...@@ -38,101 +44,103 @@ class SerialPort(abstract.FileDescriptor):
0) 0)
self.reactor = reactor self.reactor = reactor
self.protocol = protocol self.protocol = protocol
self.outQueue = [] self.deviceName = deviceName
self.closed = 0
self.closedNotifies = 0
self.writeInProgress = 0
self.protocol = protocol
self._overlappedRead = win32file.OVERLAPPED() self._overlappedRead = win32file.OVERLAPPED()
self._overlappedRead.hEvent = win32event.CreateEvent(None, 1, 0, None) self._overlappedRead.hEvent = win32event.CreateEvent(None, 1, 0, None)
self._overlappedWrite = win32file.OVERLAPPED() self._overlappedWrite = win32file.OVERLAPPED()
self._overlappedWrite.hEvent = win32event.CreateEvent(None, 0, 0, None) self._overlappedWrite.hEvent = win32event.CreateEvent(None, 0, 0, None)
self.reactor.addEvent(self._overlappedRead.hEvent, self, 'serialReadEvent')
self.reactor.addEvent(self._overlappedWrite.hEvent, self, 'serialWriteEvent')
self.reactor.addEvent( def initSoft(self, protocol, deviceName, reactor):
self._overlappedRead.hEvent, self.outQueue = []
self, self.closed = 0
'serialReadEvent') self.closedNotifies = 0
self.reactor.addEvent( self.writeInProgress = 0
self._overlappedWrite.hEvent, self.conneted = 1
self, self._reconnInProgress = False
'serialWriteEvent')
self.protocol.makeConnection(self) self.protocol.makeConnection(self)
self._finishPortSetup() self._startReading()
def _finishPortSetup(self): def _startReading(self, len=4096):
"""
Finish setting up the serial port.
This is a separate method to facilitate testing.
"""
rc, self.read_buf = win32file.ReadFile(self.hComPort, rc, self.read_buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(1), win32file.AllocateReadBuffer(len),
self._overlappedRead) self._overlappedRead)
def serialReadEvent(self): def serialReadEvent(self):
# get that character we set up logger.debug("serialReadEvent %s %s" % (self._overlappedRead.Internal, self._overlappedRead.InternalHigh))
try: try:
n = win32file.GetOverlappedResult( n = win32file.GetOverlappedResult(self.hComPort, self._overlappedRead, 1)
self.hComPort, except Exception as e:
self._overlappedRead, logger.debug("Exception %s" % e)
0) sleep(10)
except Exception: logger.debug(self.connLost())
import time
time.sleep(10)
n = 0 n = 0
if n: if n > 0:
first = str(self.read_buf[:n]) # handle the received data:
# now we should get everything that is already in the buffer (max self.protocol.dataReceived(to_bytes(self.read_buf[:n]))
# 4096) # set up next read
win32event.ResetEvent(self._overlappedRead.hEvent) win32event.ResetEvent(self._overlappedRead.hEvent)
rc, buf = win32file.ReadFile(self.hComPort, self._startReading()
win32file.AllocateReadBuffer(4096),
self._overlappedRead)
n = win32file.GetOverlappedResult(
self.hComPort,
self._overlappedRead,
1)
# handle all the received data:
self.protocol.dataReceived(first + str(buf[:n]))
# set up next one
win32event.ResetEvent(self._overlappedRead.hEvent)
rc, self.read_buf = win32file.ReadFile(self.hComPort,
win32file.AllocateReadBuffer(1),
self._overlappedRead)
def write(self, data): def write(self, data):
if data: if data:
if isinstance(data, str):
data = str.encode(data)
if self.writeInProgress: if self.writeInProgress:
self.outQueue.append(data) self.outQueue.append(data)
logger.debug("added to queue") logger.debug("added to queue")
else: else:
self.writeInProgress = 1 self.writeInProgress = 1
win32file.WriteFile(self.hComPort, data, self._overlappedWrite) ret, n = win32file.WriteFile(self.hComPort, data, self._overlappedWrite)
logger.debug("Writed to file") logger.debug("Writed to file %s", ret)
def serialWriteEvent(self): def serialWriteEvent(self):
logger.debug("serialWriteEvent %s %s" % (self._overlappedWrite.Internal, self._overlappedWrite.InternalHigh))
if self._overlappedWrite.Internal < 0 and self._overlappedWrite.InternalHigh == 0 : # DANGER: Not documented variables
logger.debug(self.connLost())
self.writeInProgress = 0
return
try: try:
dataToWrite = self.outQueue.pop(0) dataToWrite = self.outQueue.pop(0)
except IndexError: except IndexError:
self.writeInProgress = 0 self.writeInProgress = 0
return return
else: else:
win32file.WriteFile( win32file.WriteFile(self.hComPort, dataToWrite, self._overlappedWrite)
self.hComPort,
dataToWrite, def connLost(self):
self._overlappedWrite) if self._reconnInProgress:
return None
self._reconnInProgress = True
return self.reactor.callLater(30, self.connectionLostEvent, self)
def connectionLostEvent(self, reason):
abstract.FileDescriptor.connectionLost(self, reason)
self.protocol.connectionLost(reason)
logger.debug("Reconecting after 30s")
# sleep(30)
self.initSoft(self.protocol, self.deviceName, self.reactor)
def connectionLost(self, reason):
def connectionLost(self, reason=None):
""" """
Called when the serial port disconnects. Called when the serial port disconnects.
Will call C{connectionLost} on the protocol that is handling the Will call C{connectionLost} on the protocol that is handling the
serial data. serial data.
""" """
# import pdb; pdb.set_trace()
win32file.CancelIo(self.hComPort)
self.reactor.removeEvent(self._overlappedRead.hEvent) self.reactor.removeEvent(self._overlappedRead.hEvent)
self.reactor.removeEvent(self._overlappedWrite.hEvent) self.reactor.removeEvent(self._overlappedWrite.hEvent)
win32file.CloseHandle(self._overlappedRead.hEvent)
win32file.CloseHandle(self._overlappedWrite.hEvent)
abstract.FileDescriptor.connectionLost(self, reason) abstract.FileDescriptor.connectionLost(self, reason)
win32file.CloseHandle(self.hComPort) win32file.CloseHandle(self.hComPort)
self.protocol.connectionLost(reason) self.protocol.connectionLost(reason)
logger.debug("Hard reconecting after 10s")
sleep(10)
self.initHard(self.protocol, self.deviceName, self.reactor)
self.initSoft(self.protocol, self.deviceName, self.reactor)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment