ssh.py 2.92 KB
Newer Older
1 2 3
from base64 import decodestring
from struct import unpack
import binascii
4
import unittest
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59


class InvalidKeyType(Exception):
    """Raised when a key-type name is not one that PubKey accepts."""


class InvalidKey(Exception):
    """Raised when a key blob fails validation (bad base64 or wrong embedded type)."""


class PubKey(object):
    """An OpenSSH public key: key type, base64-encoded key blob and comment.

    Instances compare equal and hash alike when all of their instance
    attributes (``key_type``, ``key``, ``comment``) are equal, so duplicate
    keys collapse inside sets and as dict keys.
    """

    # Accepted key-type names. 'ssh-dsa' and 'ssh-ecdsa' are kept for
    # backward compatibility, but OpenSSH never emits them; the standard
    # identifiers for DSA (RFC 4253), ECDSA (RFC 5656) and Ed25519
    # (RFC 8709) keys are accepted as well.
    key_types = ('ssh-rsa', 'ssh-dsa', 'ssh-ecdsa',
                 'ssh-dss', 'ssh-ed25519',
                 'ecdsa-sha2-nistp256', 'ecdsa-sha2-nistp384',
                 'ecdsa-sha2-nistp521')

    # Text type used for the comment: ``unicode`` on Python 2, ``str`` on
    # Python 3 (where the ``unicode`` builtin no longer exists).
    try:
        _text_type = unicode
    except NameError:
        _text_type = str

    # http://stackoverflow.com/questions/2494450/ssh-rsa-public-key-
    # validation-using-a-regular-expression
    @classmethod
    def validate_key(cls, key_type, key):
        """Check that the base64 ``key`` blob is well formed for ``key_type``.

        The decoded blob must start with a 4-byte big-endian length followed
        by that many bytes spelling ``key_type``.

        Raises:
            InvalidKey: if ``key`` is not valid base64, decodes to fewer than
                four bytes, or embeds a different key-type name.
        """
        try:
            # The same call base64.decodestring made on Python 2; it also
            # accepts ASCII text on Python 3, where decodestring is gone.
            data = binascii.a2b_base64(key)
        except (binascii.Error, ValueError):
            raise InvalidKey()
        int_len = 4  # size of the big-endian uint32 length prefix
        if len(data) < int_len:
            # Too short to hold the length prefix; unpack() would otherwise
            # raise struct.error instead of InvalidKey.
            raise InvalidKey()
        str_len = unpack('>I', data[:int_len])[0]
        try:
            # Compare bytes with bytes: on Python 3 a str never equals bytes.
            expected = key_type.encode('ascii')
        except UnicodeError:
            raise InvalidKey()
        if data[int_len:int_len + str_len] != expected:
            raise InvalidKey()

    def __init__(self, key_type, key, comment):
        """Build a key after checking its type name and its blob.

        Raises:
            InvalidKeyType: if ``key_type`` is not listed in ``key_types``.
            InvalidKey: if ``key`` fails ``validate_key``.
        """
        if key_type not in self.key_types:
            raise InvalidKeyType()
        self.key_type = key_type

        self.validate_key(key_type, key)
        self.key = key

        self.comment = self._text_type(comment)

    def __hash__(self):
        # Hash by value, consistent with __eq__ (all attributes are strings).
        return hash(frozenset(self.__dict__.items()))

    def __eq__(self, other):
        # Only PubKeys are comparable; defer to Python for anything else
        # instead of crashing on objects that have no __dict__.
        if not isinstance(other, PubKey):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    @classmethod
    def from_str(cls, line):
        """Parse a ``"<key-type> <base64-key> [comment]"`` line.

        The comment is optional and may itself contain spaces; surrounding
        whitespace (including a trailing newline) is ignored.

        Raises:
            ValueError: if the line has fewer than two fields.
            InvalidKeyType, InvalidKey: as raised by ``__init__``.
        """
        parts = line.strip().split(None, 2)
        if len(parts) < 2:
            raise ValueError(
                'expected "<key-type> <base64-key> [comment]", got %r'
                % (line,))
        comment = parts[2] if len(parts) > 2 else u''
        return cls(parts[0], parts[1], comment)

    def __unicode__(self):
        return u' '.join((self.key_type, self.key, self.comment))

    if _text_type is str:
        # Python 3: str() is the text conversion.
        __str__ = __unicode__
    else:
        def __str__(self):
            # Python 2: str() must return bytes.
            return self.__unicode__().encode('utf-8')

    def __repr__(self):
        # Use the native str type on both Python versions.
        return '<PubKey: %s>' % str(self)


60
# Unit tests
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104

class SshTestCase(unittest.TestCase):
    """Tests for PubKey construction, validation, rendering, equality, hashing."""

    def setUp(self):
        # Two identical keys plus one whose blob differs in the final byte.
        self.p1, self.p2, self.p3 = [
            PubKey.from_str(text) for text in (
                'ssh-rsa AAAAB3NzaC1yc2EA comment',
                'ssh-rsa AAAAB3NzaC1yc2EA comment',
                'ssh-rsa AAAAB3NzaC1yc2EC comment',
            )
        ]

    def test_invalid_key_type(self):
        with self.assertRaises(InvalidKeyType):
            PubKey('ssh-inv', 'x', 'comment')

    def test_valid_key(self):
        # Construction alone must not raise.
        PubKey(key_type='ssh-rsa', key='AAAAB3NzaC1yc2EA', comment='comment')

    def test_invalid_key(self):
        # 'x' is not decodable base64.
        with self.assertRaises(InvalidKey):
            PubKey('ssh-rsa', 'x', 'comment')

    def test_invalid_key2(self):
        # Decodes fine, but the embedded type name is not 'ssh-rsa'.
        with self.assertRaises(InvalidKey):
            PubKey('ssh-rsa', 'AAAAB3MzaC1yc2EA', 'comment')

    def test_repr(self):
        key = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
        wanted = '<PubKey: ssh-rsa AAAAB3NzaC1yc2EA comment>'
        self.assertEqual(repr(key), wanted)

    def test_unicode(self):
        key = PubKey('ssh-rsa', 'AAAAB3NzaC1yc2EA', 'comment')
        wanted = 'ssh-rsa AAAAB3NzaC1yc2EA comment'
        self.assertEqual(unicode(key), wanted)

    def test_from_str(self):
        text = 'ssh-rsa AAAAB3NzaC1yc2EA comment'
        self.assertEqual(unicode(PubKey.from_str(text)), text)

    def test_eq(self):
        self.assertEqual(self.p1, self.p2)
        self.assertNotEqual(self.p1, self.p3)

    def test_hash(self):
        # p1 and p2 are equal, so they must collapse to one set entry.
        distinct = {self.p1, self.p2, self.p3}
        self.assertEqual(len(distinct), 2)

105

106 107
if __name__ == '__main__':
    # Only when this file is executed as a script: run the unit tests in
    # this module through unittest's command-line runner.
    unittest.main()