#!/usr/bin/env python

import sys, os, socket, time, asyncore, asynchat, sha256, srp6a, aes, gcm

# SRP group parameters: the modulus N and the generator g.
# This number is known to be prime, and 2 is one of its primitive roots.
# NOTE(review): neither property is checked at runtime -- re-verify both
# before ever changing N or g.
N = 125617018995153554710546479714086468244499594888726646874671447258204721048803L
g = 2
# Module-wide SRP helper built from the group (N, g) and the hash function
# exported by the sha256 module; the 'login' command calls
# srp.authenticate() on it.
srp = srp6a.SRP(N, g, sha256.hash)

class ServerError(Exception):
    """Protocol-level failure that is reported back to the client.

    The first constructor argument is the human-readable message; the
    request handler sends it to the client prefixed with '-'.
    """

class Storage:
    """Flat-file record store: record `name` lives in `<path>/<name>.pu`."""

    def __init__(self, path):
        """Remember the directory `path` that holds the '.pu' files."""
        self.path = path

    def _filename(self, name):
        """Return the on-disk path of the record called `name`."""
        return os.path.join(self.path, name + '.pu')

    def exists(self, name):
        """Return True if a record called `name` is present on disk."""
        return os.path.exists(self._filename(name))

    def read(self, name):
        """Return the full contents of record `name` as a string.

        Raises IOError if the record cannot be opened or read.
        """
        handle = open(self._filename(name))
        try:
            return handle.read()
        finally:
            # Close even when read() fails so the descriptor is not leaked.
            handle.close()

    def write(self, name, data):
        """Replace the contents of record `name` with `data`.

        The data is written to '<record>.tmp' first and then renamed over
        the real file, so a reader never observes a half-written record.
        """
        path = self._filename(name)
        temp = path + '.tmp'
        handle = open(temp, 'w')
        try:
            handle.write(data)
        finally:
            # Close even when write() fails so the descriptor is not leaked.
            handle.close()
        os.rename(temp, path)

    def create(self, name):
        """Atomically create an empty record called `name`.

        Returns True if the record was created and False if a record with
        that name already exists.  Any other failure (missing directory,
        permission denied, ...) is raised as OSError rather than reported
        as False, because callers retry with a different name on False and
        would otherwise loop forever on a persistent error.
        """
        import errno
        try:
            # O_EXCL makes the existence check and the creation one step.
            fd = os.open(self._filename(name), os.O_CREAT | os.O_EXCL)
        except OSError:
            if sys.exc_info()[1].errno == errno.EEXIST:
                return False
            raise
        os.close(fd)
        return True

    def remove(self, name):
        """Delete record `name`; return True on success, False on OSError."""
        try:
            os.remove(self._filename(name))
            return True
        except OSError:
            # Deliberately best-effort: a missing record is not an error.
            return False

    def list(self):
        """Yield the name (without the '.pu' suffix) of every stored record."""
        for filename in os.listdir(self.path):
            if filename.endswith('.pu'):
                yield filename[:-3]

class Listener(asyncore.dispatcher):
    """Listening TCP socket that hands each new connection to a Request."""

    def __init__(self, port, storage):
        """Bind to `port` on all interfaces and start listening.

        `storage` is passed on unchanged to every Request created for an
        accepted connection.
        """
        asyncore.dispatcher.__init__(self)
        self.storage = storage
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(('', port))
        self.listen(12)

    def tick(self):
        """Periodic hook called by the main loop; the listener has no timeout."""
        pass

    def handle_accept(self):
        """Accept a pending connection and start serving it."""
        pair = self.accept()
        if pair is None:
            # accept() returns None when no connection is actually pending
            # (e.g. the peer went away first).  Passing None on would raise
            # inside Request, and asyncore's default error handling would
            # then close the listening socket.
            return
        Request(pair, self.storage, timeout=60)
        
def get_username(username):
    """Validate a client-supplied username and return it unchanged.

    A username is accepted only if it is non-empty and consists solely of
    lowercase ASCII letters, digits and underscores; otherwise ServerError
    is raised.
    """
    allowed = 'abcdefghijklmnopqrstuvwxyz0123456789_'
    acceptable = True
    for char in username:
        if char not in allowed:
            acceptable = False
            break
    if acceptable and username:
        return username
    raise ServerError('invalid username %r' % username)

def get_int(line):
    """Parse a protocol line as a decimal integer.

    Returns the parsed value; raises ServerError if int() rejects the text.
    """
    try:
        value = int(line)
    except ValueError:
        raise ServerError('invalid integer %r' % line)
    else:
        return value

class Request(asynchat.async_chat):
    """One client connection speaking the newline-terminated protocol.

    Incoming lines are fed to a generator-based protocol engine (see
    make_engine); the strings the engine yields are sent back to the
    client.  Normal replies are prefixed with '+', ServerError messages
    with '-' and any other exception with '!'.  After a successful login
    both directions are encrypted and each line is hex-encoded.
    """

    # Class-wide connection counter.  __init__ increments it and then
    # shadows it on the instance with an 'N:' prefix used in log output.
    id = 0

    def __init__(self, (conn, addr), storage, timeout=10):
        """Start serving the connection `conn` accepted from `addr`.

        `storage` is the record store used by the protocol commands and
        `timeout` is the idle time in seconds after which tick() closes
        the connection.
        """
        Request.id += 1
        self.id = '%d:' % Request.id
        print self.id, 'connection from', addr
        asynchat.async_chat.__init__(self, conn=conn)
        self.last_activity = time.time()
        self.storage = storage
        self.timeout = timeout
        # Bytes of the line currently being received.
        self.request = ''
        # Cipher modes for incoming/outgoing lines; None means plaintext.
        self.inmode = None
        self.outmode = None
        self.engine = self.make_engine()
        # The first value the engine yields is the function used to hand
        # it incoming lines.
        self.engine_push = self.engine.next()
        self.set_terminator('\n')
        # Run the engine up to its first request for input.
        self.proceed()

    def tick(self):
        """Close the connection once it has been idle longer than timeout."""
        if time.time() - self.last_activity > self.timeout:
            self.close_when_done()

    def collect_incoming_data(self, data):
        """Buffer received data and record the time of the activity."""
        self.last_activity = time.time()
        self.request += data

    def push_line(self, line):
        """Send `line` to the client, encrypting it if outmode is set.

        An encrypted line goes out as '+' followed by the hex encoding of
        the ciphertext and its MAC.
        """
        if self.outmode:
            print self.id, '=>', repr(line)
            data, mac = self.outmode.encrypt(line)
            line = '+' + ''.join(['%02x' % ord(c) for c in data + mac])
        print self.id, '->', line
        self.push(line + '\n')

    def found_terminator(self):
        """Handle one complete input line.

        The line is decrypted first when inmode is set, then handed to the
        engine.  Any failure is reported to the client and the connection
        is closed once the report has been sent.
        """
        line = self.request.strip()
        print self.id, '<-', line
        if self.inmode:
            try:
                # Hex-decode the line; the last 16 bytes are passed to
                # decrypt() separately from the rest.
                data = ''.join([chr(int(line[i:i+2], 16))
                                for i in range(0, len(line), 2)])
                line = self.inmode.decrypt(data[:-16], data[-16:])
                print self.id, '<=', repr(line)
            except ValueError:
                # Special case: abort if encrypted message is invalid.
                # Both cipher modes are dropped first, so the error below
                # is sent unencrypted.
                self.inmode = self.outmode = None
                self.discard_buffers()
                e = ServerError('protocol encryption')
                print self.id, '[error]', e.args[0]
                self.push_line('-%s' % e.args[0])
                self.close_when_done()
                return
        try:
            self.engine_push(line)
        except ServerError, e:
            # Expected protocol error: report it with a '-' prefix.
            self.discard_buffers()
            print self.id, '[error]', e.args[0]
            self.push_line('-%s' % e.args[0])
            self.close_when_done()
            return
        except:
            # Anything else is unexpected: report it with a '!' prefix.
            self.discard_buffers()
            exception = sys.exc_info()[1]
            print self.id, '[panic]', exception
            self.push_line('!%s' % exception)
            self.close_when_done()
            return
        self.request = ''
        self.proceed()

    def handle_close(self):
        """Log the close event and close the socket."""
        print self.id, '[client closed]'
        self.close()

    def proceed(self):
        """Run the engine until it needs more input or finishes.

        Each non-None message the engine yields is sent with a '+' prefix.
        A yielded None means the engine is waiting for the next input
        line.  When the engine finishes or raises, the connection is
        scheduled to close.
        """
        try:
            for message in self.engine:
                if message is None:
                    # Return to wait for more input.
                    return
                self.push_line('+%s' % message)
            print self.id, '[done]'
        except ServerError, e:
            print self.id, '[error]', e.args[0]
            self.push_line('-%s' % e.args[0])
        except:
            exception = sys.exc_info()[1]
            print self.id, '[panic]', exception
            self.push_line('!%s' % exception)
        # Either an error occurred or the engine is exhausted.
        self.close_when_done()

    def make_engine(self):
        """Return the generator that implements the protocol conversation.

        The first value yielded is a function that accepts one incoming
        line.  Every later value is either a string to send to the client
        or None, meaning the engine must be given another input line
        before it is resumed.  Protocol violations raise ServerError.
        """
        # First emit the function by which we receive incoming messages.
        queue = []
        # push() looks dispatch_push up on every call, so rebinding this
        # local later (during login) redirects where incoming lines go.
        dispatch_push = queue.append
        def push(value):
            dispatch_push(value)
        yield push

        # Greeting sent as soon as the connection is set up.
        yield 'Passpet (passpetd/0.1 Python/%s)' % sys.version.split()[0]

        # Make sure the protocol version matches.
        # Each 'yield None' pauses until one more line has been pushed
        # onto queue.
        yield None; protocol = queue.pop(0)
        if protocol != '1':
            raise ServerError('protocol unsupported %r' % protocol)
        yield ''

        # Dispatch on the command given on the first line.
        yield None; command = queue.pop(0)
        if command == 'create':
            # Input: username, k1, salt, verifier.  Reply: the new index.
            yield None; username = get_username(queue.pop(0))
            yield None; k1 = get_int(queue.pop(0))
            yield None; salt = get_int(queue.pop(0))
            yield None; verifier = get_int(queue.pop(0))
            # Claim the first unused '<username>-<index>' record.
            index = 0
            while not self.storage.create('%s-%d' % (username, index)):
                index += 1
            data = '%s\n%s\n%s\n' % (k1, salt, verifier)
            self.storage.write('%s-%d' % (username, index), data)
            print self.id, 'create', username, index
            yield str(index)

        elif command == 'list':
            # Input: username.  Reply: space-separated 'index:k1' pairs
            # for that user's records, sorted by index.
            yield None; username = get_username(queue.pop(0))
            print self.id, 'list', username
            items = []
            for name in self.storage.list():
                if name.startswith(username + '-'):
                    index = name[len(username) + 1:]
                    items.append((int(index), name))
            items.sort()
            results = []
            for index, name in items:
                data = self.storage.read(name)
                # k1 is the first line of the stored record.
                k1 = int(data.split('\n')[0])
                results.append('%d:%d' % (index, k1))
            yield ' '.join(results)

        elif command == 'login':
            # The username and index select a login account.
            yield None; username = get_username(queue.pop(0))
            yield None; index = get_int(queue.pop(0))
            print self.id, 'login', username, index
            if not self.storage.exists('%s-%d' % (username, index)):
                raise ServerError('invalid account %s-%d' % (username, index))
            data = self.storage.read('%s-%d' % (username, index))
            # Record layout: k1, salt and verifier on the first three
            # lines; everything after that is the stored file.
            k1, salt, verifier, file = data.split('\n', 3)
            k1, salt, verifier = int(k1), int(salt), int(verifier)

            # Hook up this channel to the SRP authenticator.
            try:
                # R.key is read below as the cipher key; presumably
                # srp.authenticate() fills it in (srp6a is not shown here).
                R = srp6a.Object()
                server = srp.authenticate(username, salt, verifier, R)
                server_push = server.next()
                # While SRP runs, incoming lines are parsed as integers and
                # handed to the authenticator instead of the queue.
                dispatch_push = lambda line: server_push(get_int(line))
                for message in server:
                    yield message
            except srp6a.SRPError, e:
                raise ServerError('login %s' % e.args[0])
            # Route incoming lines back to the queue.
            dispatch_push = queue.append

            # Login completed; send a nonce to begin the encrypted session.
            cipher = aes.AES(128, R.key)
            noncebytes = map(ord, os.urandom(16))
            yield ''.join(['%02x' % b for b in noncebytes])

            # The rest of the session is encrypted.
            # The two directions use different nonces: nonce ^ 1 for
            # incoming lines and nonce for outgoing lines.
            nonce = aes.tolong(noncebytes)
            self.inmode = gcm.GCM(cipher, nonce=nonce ^ 1)
            self.outmode = gcm.GCM(cipher, nonce=nonce)

            yield None; command = queue.pop(0)
            if command == 'delete':
                # The result of remove() is ignored; the reply is always ''.
                print self.id, 'delete'
                self.storage.remove('%s-%d' % (username, index))
                yield ''
            elif command == 'read':
                print self.id, 'read'
                yield file
            elif command == 'write':
                yield None; old_mac = queue.pop(0)
                yield None; new_file = queue.pop(0)
                print self.id, 'write', repr(old_mac), repr(new_file)
                # Accept the write only if old_mac equals the last 16
                # characters of the file as it was read at login.
                if file[-16:] == old_mac:
                    data = '%s\n%s\n%s\n%s' % (k1, salt, verifier, new_file)
                    self.storage.write('%s-%d' % (username, index), data)
                    yield ''
                else:
                    raise ServerError('write MAC incorrect')
            else:
                raise ServerError('invalid command %r' % command)
        else:
            raise ServerError('invalid command %r' % command)

# Command line: [port [path]].  Padding with None lets missing arguments
# fall through to the defaults below.
args = sys.argv[1:] + [None, None]
# Test the raw argument rather than the parsed number, so that an explicit
# port of 0 is honoured instead of being replaced by the default.
if args[0]:
    port = int(args[0])
else:
    port = 7277
path = args[1] or '.'
Listener(port, Storage(path))
print('Passpet Storage Server (port %d, path %s) is ready.' % (port, path))
while True:
    # Poll the sockets for at most one second, then let every channel run
    # its periodic housekeeping (idle-timeout checks).
    asyncore.loop(timeout=1, count=1)
    # Iterate over a snapshot so a channel that closes itself during
    # tick() cannot disturb the iteration.
    for channel in list(asyncore.socket_map.values()):
        channel.tick()
