Bugfixes
[re6stnet.git] / re6st-registry
1 #!/usr/bin/env python
2 import random, select, smtplib, sqlite3, string, socket
3 import subprocess, time, threading, traceback, errno, logging, os, xmlrpclib
4 from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler
5 from email.mime.text import MIMEText
6 from OpenSSL import crypto
7 from re6st  import utils
8
9 # To generate the CA certificate (from an existing ca.key) for network 2001:db8:42::/48; the serial is the prefix preceded by a single 1 bit:
10 # openssl req -nodes -new -x509 -key ca.key -set_serial 0x120010db80042 -days 365 -out ca.crt
11
# Socket option numbers used to make the IPv6 XML-RPC listener IPv6-only.
# Prefer the values exported by the socket module: IPV6_V6ONLY differs between
# platforms (26 on Linux, 27 on the BSDs).  Fall back to the Linux numbers
# when the module does not export them.
IPV6_V6ONLY = getattr(socket, 'IPV6_V6ONLY', 26)
SOL_IPV6 = getattr(socket, 'IPPROTO_IPV6', 41)
14
15
class RequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler that hands itself to the invoked method.

    Every registered method is called with this handler prepended to the
    XML-RPC parameters, which gives it access to the client address.
    """

    def address_string(self):
        # Report the raw client IP instead of the base class's reverse-DNS
        # lookup (workaround for http://bugs.python.org/issue6085).
        host = self.client_address[0]
        return host

    def _dispatch(self, method, params):
        logging.debug('%s%r', method, params)
        call_args = (self,) + params
        return self.server._dispatch(method, call_args)
25
class SimpleXMLRPCServer4(SimpleXMLRPCServer):
    """IPv4 XML-RPC server whose port can be rebound right after a restart."""

    # Makes the server set SO_REUSEADDR before binding.
    allow_reuse_address = True
29
30
class SimpleXMLRPCServer6(SimpleXMLRPCServer4):
    """IPv6-only variant of SimpleXMLRPCServer4.

    IPV6_V6ONLY is switched on before binding, so the '::' listener does not
    also take IPv4 connections; those are left to a separate IPv4 server
    bound to the same port.
    """

    address_family = socket.AF_INET6

    def server_bind(self):
        sock = self.socket
        sock.setsockopt(SOL_IPV6, IPV6_V6ONLY, 1)
        SimpleXMLRPCServer4.server_bind(self)
38
39
class main(object):
    """re6stnet registry: certificate authority and peer directory.

    Instantiating the class parses the command line, opens (and initialises
    if needed) the SQLite database, loads the CA certificate and key, then
    serves XML-RPC requests forever on both IPv4 and IPv6.  The instance is
    registered with ``register_instance``, so every method whose name does not
    start with '_' is callable over XML-RPC.  Such methods receive the request
    handler as their first argument after ``self``.
    """

    def __init__(self):
        """Set up the registry from the command line and serve forever."""
        self.cert_duration = 365 * 86400  # certificate validity: 1 year, in seconds
        # NOTE(review): the next three attributes are not used in this file.
        self.time_out = 45000
        self.refresh_interval = 600
        self.last_refresh = time.time()

        utils.setupLog(3)

        # Command line parsing
        parser = utils.ArgParser(fromfile_prefix_chars='@',
                description='Peer discovery http server for re6stnet')
        _ = parser.add_argument
        _('--port', type=int, default=80, help='Port of the host server')
        _('--db', required=True,
                help='Path to database file')
        _('--ca', required=True,
                help='Path to ca.crt file')
        _('--key', required=True,
                help='Path to certificate key')
        _('--mailhost', required=True,
                help='SMTP server mail host')
        _('--private',
                help='VPN IP of the node on which runs the registry')
        self.config = parser.parse_args()

        if not self.config.private:
            logging.warning('You have declared no private address'
                    ', either this is the first start, or you should'
                    ' check your configuration')

        # Database initializing (autocommit mode: transactions are explicit)
        self.db = sqlite3.connect(self.config.db, isolation_level=None)
        self._createSchema()

        # Loading certificates
        with open(self.config.ca) as f:
            self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        with open(self.config.key) as f:
            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
        # The CA serial is the network prefix preceded by a single 1 bit
        # (e.g. 0x120010db80042 for 2001:db8:42::/48); bin() yields '0b1...',
        # so dropping 3 characters leaves exactly the prefix bits.
        self.network = bin(self.ca.get_serial_number())[3:]
        logging.info("Network prefix : %s/%u", self.network, len(self.network))

        # Starting server
        server4 = SimpleXMLRPCServer4(('0.0.0.0', self.config.port),
                requestHandler=RequestHandler, allow_none=True)
        server4.register_instance(self)
        server6 = SimpleXMLRPCServer6(('::', self.config.port),
                requestHandler=RequestHandler, allow_none=True)
        server6.register_instance(self)

        self._serveForever((server4, server6))

    def _createSchema(self):
        """Create the registry tables on ``self.db`` if they are missing.

        The ``cert`` table is seeded with one free row for the empty prefix
        (the whole network) only when that table is created, so calling this
        on an existing database is harmless.

        Raises:
            sqlite3.OperationalError: for any failure other than ``cert``
                already existing.
        """
        self.db.execute("""CREATE TABLE IF NOT EXISTS peers (
                        prefix text primary key not null,
                        address text not null,
                        date integer default (strftime('%s','now')))""")
        self.db.execute("CREATE INDEX IF NOT EXISTS peers_ping ON peers(date)")
        self.db.execute("""CREATE TABLE IF NOT EXISTS token (
                        token text primary key not null,
                        email text not null,
                        prefix_len integer not null,
                        date integer not null)""")
        try:
            self.db.execute("""CREATE TABLE cert (
                               prefix text primary key not null,
                               email text,
                               cert text)""")
        except sqlite3.OperationalError as e:
            if e.args[0] != 'table cert already exists':
                raise  # keep the real error instead of a bare RuntimeError
        else:
            self.db.execute("INSERT INTO cert VALUES ('',null,null)")

    def _serveForever(self, servers):
        """Handle requests on *servers* one at a time, forever."""
        servers = list(servers)
        while True:
            try:
                readable = select.select(servers, [], [])[0]
            except (OSError, select.error) as e:
                # A signal interrupting select() is not an error: retry.
                if e.args[0] != errno.EINTR:
                    raise
            else:
                for server in readable:
                    server._handle_request_noblock()

    def _newToken(self, email, prefix_len=16):
        """Store and return a new random 8-letter token bound to *email*."""
        # SystemRandom draws from the OS CSPRNG: tokens grant certificates,
        # so they must not be predictable.
        rng = random.SystemRandom()
        while True:
            token = ''.join(rng.sample(string.ascii_lowercase, 8))
            try:
                self.db.execute("INSERT INTO token VALUES (?,?,?,?)",
                                (token, email, prefix_len, int(time.time())))
                return token
            except sqlite3.IntegrityError:
                pass  # collides with a pending token: draw another one

    def requestToken(self, handler, email):
        """Create a token for *email* and send it there by mail."""
        token = self._newToken(email)

        # Creating and sending email
        me = 'postmaster@re6st.net'
        msg = MIMEText('Hello world !\nYour token : %s' % (token,))  # XXX
        msg['Subject'] = '[re6stnet] Token Request'
        msg['From'] = me
        msg['To'] = email
        s = smtplib.SMTP(self.config.mailhost)
        try:
            s.sendmail(me, email, msg.as_string())
            s.quit()
        finally:
            s.close()  # release the connection even if sendmail failed

    def _getPrefix(self, prefix_len):
        """Return a free prefix (bit string) of length *prefix_len*.

        Free prefixes are rows of ``cert`` with a null ``cert``.  The longest
        free prefix no longer than *prefix_len* is split in two repeatedly
        (``p`` -> ``p+'1'`` stays free, ``p+'0'`` is kept) until it has the
        requested length.  The all-zero prefix of maximal length is marked
        'reserved' and never handed out.  The returned row is not marked as
        used here: the caller must store a certificate in it.

        Raises:
            ValueError: if *prefix_len* is out of range for the network.
            StopIteration: if no free prefix is available.
        """
        max_len = 128 - len(self.network)
        if not 0 < prefix_len <= max_len:
            raise ValueError('Invalid prefix length %r (must be in 1..%u)'
                             % (prefix_len, max_len))
        try:
            prefix, = next(self.db.execute(
                """SELECT prefix FROM cert WHERE length(prefix) <= ? AND cert is null
                   ORDER BY length(prefix) DESC""", (prefix_len,)))
        except StopIteration:
            logging.error('There are no more free /%s prefix available', prefix_len)
            raise
        while len(prefix) < prefix_len:
            self.db.execute("UPDATE cert SET prefix = ? WHERE prefix = ?", (prefix + '1', prefix))
            prefix += '0'
            self.db.execute("INSERT INTO cert VALUES (?,null,null)", (prefix,))
        if len(prefix) < max_len or '1' in prefix:
            return prefix
        self.db.execute("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,))
        return self._getPrefix(prefix_len)

    def _signRequest(self, token, req):
        """Consume *token*, allocate a prefix and return a PEM cert for *req*.

        Must run inside a transaction (see requestCertificate).
        """
        try:
            token, email, prefix_len, _ = next(self.db.execute(
                "SELECT * FROM token WHERE token = ?", (token,)))
        except StopIteration:
            logging.exception('Bad token (%s) in request', token)
            raise
        self.db.execute("DELETE FROM token WHERE token = ?", (token,))

        # Get a new prefix
        prefix = self._getPrefix(prefix_len)

        # Create certificate
        cert = crypto.X509()
        # NOTE(review): no serial number is set, so every certificate keeps
        # pyOpenSSL's default serial. Confirm whether unique serials are needed.
        cert.gmtime_adj_notBefore(0)
        cert.gmtime_adj_notAfter(self.cert_duration)
        cert.set_issuer(self.ca.get_subject())
        subject = req.get_subject()
        # The CN encodes the allocated prefix as "<integer value>/<length>".
        subject.CN = "%u/%u" % (int(prefix, 2), prefix_len)
        cert.set_subject(subject)
        cert.set_pubkey(req.get_pubkey())
        cert.sign(self.key, 'sha1')
        cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)

        # Insert certificate into db
        self.db.execute("UPDATE cert SET email = ?, cert = ? WHERE prefix = ?",
                        (email, cert, prefix))
        return cert

    def requestCertificate(self, handler, token, cert_req):
        """Sign the PEM certificate request *cert_req* in exchange for *token*.

        The token is consumed and a free prefix of the length recorded with
        the token is allocated to the new certificate.  Returns the signed
        certificate in PEM format.  On error, the traceback is printed, the
        database changes are rolled back and the exception is re-raised.
        """
        try:
            req = crypto.load_certificate_request(crypto.FILETYPE_PEM, cert_req)
            # The connection is in autocommit mode (isolation_level=None), so
            # "with self.db:" would not roll anything back.  Open the
            # transaction explicitly so that token consumption, prefix
            # allocation and certificate storage happen atomically.
            self.db.execute("BEGIN")
            committed = False
            try:
                cert = self._signRequest(token, req)
                self.db.execute("COMMIT")
                committed = True
            finally:
                if not committed:
                    self.db.execute("ROLLBACK")
            return cert
        except Exception:
            traceback.print_exc()
            raise

    def getCa(self, handler):
        """Return the CA certificate in PEM format."""
        return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)

    def getPrivateAddress(self, handler):
        """Return the registry URL on its configured private (VPN) address."""
        return 'http://[%s]:%u' % (self.config.private, self.config.port)

    def getBootstrapPeer(self, handler, client_prefix):
        """Return a random other peer, encrypted for the client.

        The string "<prefix> <address>" of a randomly chosen declared peer
        (other than *client_prefix*) is encrypted with ``openssl rsautl``
        using the client's certificate as key. It is returned as
        ``xmlrpclib.Binary``.

        Raises:
            StopIteration: if *client_prefix* is unknown or no other peer
                has been declared.
            ValueError: if no certificate is stored for *client_prefix*.
        """
        cert, = next(self.db.execute("SELECT cert FROM cert WHERE prefix = ?",
                                     (client_prefix,)))
        if not cert or cert == 'reserved':
            # openssl would keep reading the pipe below (its write end stays
            # open until communicate() returns) and stall the single-threaded
            # server.
            raise ValueError('No certificate for prefix %r' % (client_prefix,))
        try:
            prefix, address = next(self.db.execute(
                """SELECT prefix, address FROM peers
                   WHERE prefix != ? ORDER BY random() LIMIT 1""", (client_prefix,)))
        except StopIteration:
            logging.info('No peer to send for bootstrap')
            raise
        r, w = os.pipe()
        try:
            # Feed the certificate to openssl as its -inkey file through a
            # pipe. The child inherits fd r and opens it via /proc/self/fd
            # (Linux-specific). The write happens in another thread so that a
            # certificate larger than the pipe buffer cannot block this one.
            threading.Thread(target=os.write, args=(w, cert)).start()
            p = subprocess.Popen(('openssl', 'rsautl', '-encrypt', '-certin',
                                  '-inkey', '/proc/self/fd/%u' % r),
                                 stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            logging.info("Sending bootstrap peer (%s, %s)", prefix, address)
            return xmlrpclib.Binary(p.communicate('%s %s' % (prefix, address))[0])
        finally:
            os.close(r)
            os.close(w)

    def declare(self, handler, address):
        """Record *address* as the contact address of the calling peer.

        The caller is identified by its source IP, which must lie inside the
        VPN network.  Returns True if the address was recorded, False (with a
        warning) for clients outside the network.
        """
        # client_address is (host, port) for IPv4 and (host, port, flowinfo,
        # scopeid) for IPv6; only the host is needed.
        client_ip = utils.binFromIp(handler.client_address[0])
        if client_ip.startswith(self.network):
            prefix = client_ip[len(self.network):]
            # cert rows are disjoint prefixes (see _getPrefix), so the greatest
            # one that is <= the address bits is the one containing the address.
            prefix, = next(self.db.execute(
                "SELECT prefix FROM cert WHERE prefix <= ? ORDER BY prefix DESC LIMIT 1",
                (prefix,)))
            self.db.execute("INSERT OR REPLACE INTO peers (prefix, address) VALUES (?,?)",
                            (prefix, address))
            return True
        logging.warning("Unauthorized connection from %s which does not start with %s",
                        utils.ipFromBin(client_ip),
                        utils.ipFromBin(self.network.ljust(128, '0')))
        return False
233
if __name__ == '__main__':
    # Run as a script: instantiate the registry, which serves requests
    # from its constructor.
    main()