from base64 import decodestring
from struct import unpack
import binascii


class InvalidKeyType(Exception):
    pass


class InvalidKey(Exception):
    pass


class PubKey(object):

    key_types = ('ssh-rsa', 'ssh-dsa', 'ssh-ecdsa')

    # http://stackoverflow.com/questions/2494450/ssh-rsa-public-key-
    # validation-using-a-regular-expression
    @classmethod
    def validate_key(cls, key_type, key):
        try:
            data = decodestring(key)
        except binascii.Error:
            raise InvalidKey()
        int_len = 4
        str_len = unpack('>I', data[:int_len])[0]
        if data[int_len:int_len + str_len] != key_type:
            raise InvalidKey()

    def __init__(self, key_type, key, comment):
        if key_type not in self.key_types:
            raise InvalidKeyType()
        self.key_type = key_type

        PubKey.validate_key(key_type, key)
        self.key = key

        self.comment = unicode(comment)

    def __hash__(self):
        return hash(frozenset(self.__dict__.items()))

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    @classmethod
    def from_str(cls, line):
        key_type, key, comment = line.split()
        return PubKey(key_type, key, comment)

    def __unicode__(self):
        return u' '.join((self.key_type, self.key, self.comment))

    def __repr__(self):
        return u'<PubKey: %s>' % unicode(self)


import unittest


class SshTestCase(unittest.TestCase):
    def setUp(self):
        self.p1 = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
        self.p2 = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
        self.p3 = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EC comment')

    def test_invalid_key_type(self):
        self.assertRaises(InvalidKeyType, PubKey, 'ssh-inv', 'x', 'comment')

    def test_valid_key(self):
        PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')

    def test_invalid_key(self):
        self.assertRaises(InvalidKey, PubKey, 'ssh-rsa', 'x', 'comment')

    def test_invalid_key2(self):
        self.assertRaises(InvalidKey, PubKey, 'ssh-rsa',
                          'AAAAB3MzaC1yc2EA', 'comment')

    def test_repr(self):
        p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
        self.assertEqual(
            repr(p), '<PubKey: ssh-rsa AAAAB3NzaC1yc2EA comment>')

    def test_unicode(self):
        p = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
        self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')

    def test_from_str(self):
        p = PubKey.from_str('ssh-rsa AAAAB3NzaC1yc2EA comment')
        self.assertEqual(unicode(p), 'ssh-rsa AAAAB3NzaC1yc2EA comment')

    def test_eq(self):
        self.assertEqual(self.p1, self.p2)
        self.assertNotEqual(self.p1, self.p3)

    def test_hash(self):
        s = set()
        s.add(self.p1)
        s.add(self.p2)
        s.add(self.p3)
        self.assertEqual(len(s), 2)

if __name__ == '__main__':
    unittest.main()