Change protocol to discover addresses of peers to connect to
[re6stnet.git] / re6st-registry
1 #!/usr/bin/env python
2 import random, select, smtplib, sqlite3, string, socket
3 import subprocess, time, threading, traceback, errno, logging, os, xmlrpclib
4 from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler
5 from email.mime.text import MIMEText
6 from OpenSSL import crypto
7 from re6st import tunnel, utils
8
# To generate the CA certificate from an existing key (ca.key), with the network prefix 2001:db8:42::/48 encoded in the serial number behind a leading 1 bit:
10 # openssl req -nodes -new -x509 -key ca.key -set_serial 0x120010db80042 -days 365 -out ca.crt
11
# Socket option numbers used to make the IPv6 XML-RPC server IPv6-only.
# Prefer the values exposed by the socket module (they differ between
# platforms, e.g. IPV6_V6ONLY is 26 on Linux but 27 on BSD/macOS) and fall
# back to the Linux numbers when the module does not provide them.
IPV6_V6ONLY = getattr(socket, 'IPV6_V6ONLY', 26)
SOL_IPV6 = getattr(socket, 'IPPROTO_IPV6', 41)
14
15
class RequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler that passes itself to the called method.

    A remote call ``method(*params)`` is forwarded to the server's
    dispatcher as ``method(handler, *params)``.
    """

    def address_string(self):
        """Return the bare client host, skipping the base class's lookup.

        Workaround for http://bugs.python.org/issue6085
        """
        host = self.client_address[0]
        return host

    def _dispatch(self, method, params):
        # Trace every incoming call before forwarding it.
        logging.debug('%s%r', method, params)
        forwarded = (self,) + params
        return self.server._dispatch(method, forwarded)
25
class SimpleXMLRPCServer4(SimpleXMLRPCServer):
    """IPv4 XML-RPC server whose port can be rebound right after a restart."""

    allow_reuse_address = True


class SimpleXMLRPCServer6(SimpleXMLRPCServer4):
    """IPv6 variant of SimpleXMLRPCServer4.

    The socket is made IPv6-only before binding, so that binding '::' does
    not also claim IPv4 traffic and an IPv4 server can use the same port.
    """

    address_family = socket.AF_INET6

    def server_bind(self):
        sock = self.socket
        sock.setsockopt(SOL_IPV6, IPV6_V6ONLY, 1)
        SimpleXMLRPCServer4.server_bind(self)
38
39
class main(object):
    """re6stnet registry: issues certificates and hands out bootstrap peers.

    Instantiating this class runs the registry: ``__init__`` parses the
    command line, opens the database, loads the CA and then serves XML-RPC
    requests forever over IPv4 and IPv6.  The instance itself is registered
    with the XML-RPC servers, so every method whose name does not start with
    ``_`` is remotely callable; ``RequestHandler`` prepends itself to the call
    arguments, hence the ``handler`` first parameter of those methods.
    """

    def __init__(self):
        self.cert_duration = 365 * 86400  # validity of issued certs, seconds
        # NOTE(review): time_out, refresh_interval and last_refresh are not
        # read anywhere in this file; kept in case other code relies on them.
        self.time_out = 45000
        self.refresh_interval = 600
        self.last_refresh = time.time()

        utils.setupLog(3)

        self.config = self._parseArgs()
        if self.config.private:
            # UDP socket used by getBootstrapPeer to query the re6stnet node
            # reachable at the --private address.
            self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        else:
            logging.warning('You have declared no private address'
                            ', either this is the first start, or you should'
                            ' check your configuration')

        self._initDatabase()
        self._loadCertificates()
        self._serveForever()  # never returns

    def _parseArgs(self):
        """Parse the command line and return the resulting namespace."""
        parser = utils.ArgParser(fromfile_prefix_chars='@',
                description='Peer discovery http server for re6stnet')
        _ = parser.add_argument
        _('--port', type=int, default=80, help='Port of the host server')
        _('--db', required=True,
                help='Path to database file')
        _('--ca', required=True,
                help='Path to ca.crt file')
        _('--key', required=True,
                help='Path to certificate key')
        _('--mailhost', required=True,
                help='SMTP server mail host')
        _('--private',
                help='VPN IP of the node on which runs the registry')
        return parser.parse_args()

    def _initDatabase(self):
        """Open the SQLite database, creating the token and cert tables.

        ``token`` holds pending registration tokens; ``cert`` holds the
        prefixes handed out or still free (see _getPrefix).  A freshly
        created ``cert`` table is seeded with the empty prefix, i.e. the
        whole network is free.
        """
        # isolation_level=None puts the connection in autocommit mode.
        self.db = sqlite3.connect(self.config.db, isolation_level=None)
        self.db.execute("""CREATE TABLE IF NOT EXISTS token (
                        token text primary key not null,
                        email text not null,
                        prefix_len integer not null,
                        date integer not null)""")
        try:
            self.db.execute("""CREATE TABLE cert (
                               prefix text primary key not null,
                               email text,
                               cert text)""")
        except sqlite3.OperationalError as e:
            # "Already exists" is the normal restart case; re-raise anything
            # else with its original message.
            if e.args[0] != 'table cert already exists':
                raise
        else:
            self.db.execute("INSERT INTO cert VALUES ('',null,null)")

    def _loadCertificates(self):
        """Load the CA certificate and key, and derive the network prefix."""
        with open(self.config.ca) as f:
            self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        with open(self.config.key) as f:
            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
        # The CA serial is the network prefix in binary behind a leading 1
        # bit (e.g. 0x120010db80042 for 2001:db8:42::/48): drop '0b1'.
        self.network = bin(self.ca.get_serial_number())[3:]
        logging.info("Network prefix : %s/%u", self.network, len(self.network))

    def _serveForever(self):
        """Serve XML-RPC on IPv4 and IPv6 (same port); never returns."""
        servers = []
        for server_class, host in ((SimpleXMLRPCServer4, '0.0.0.0'),
                                   (SimpleXMLRPCServer6, '::')):
            server = server_class((host, self.config.port),
                                  requestHandler=RequestHandler,
                                  allow_none=True)
            server.register_instance(self)
            servers.append(server)
        while True:
            try:
                ready = select.select(servers, [], [])[0]
            except (OSError, select.error) as e:
                # A signal interrupted select(): just wait again.
                if e.args[0] != errno.EINTR:
                    raise
            else:
                for server in ready:
                    server._handle_request_noblock()

    def requestToken(self, handler, email):
        """Create a single-use registration token and mail it to ``email``.

        The token is stored with the prefix length (16 bits, counted after
        the network prefix) that requestCertificate will allocate for it.
        ``handler`` is the XML-RPC request handler and is unused.
        """
        # Tokens authorize certificate issuance: draw them from the OS
        # CSPRNG instead of the predictable default generator.
        rng = random.SystemRandom()
        while True:
            token = ''.join(rng.sample(string.ascii_lowercase, 8))
            try:
                self.db.execute("INSERT INTO token VALUES (?,?,?,?)",
                                (token, email, 16, int(time.time())))
            except sqlite3.IntegrityError:
                continue  # token already exists: draw another one
            break

        # Creating and sending email
        me = 'postmaster@re6st.net'
        msg = MIMEText('Hello world !\nYour token : %s' % (token,))  # XXX
        msg['Subject'] = '[re6stnet] Token Request'
        msg['From'] = me
        msg['To'] = email
        s = smtplib.SMTP(self.config.mailhost)
        try:
            s.sendmail(me, email, msg.as_string())
        except Exception:
            s.close()  # don't mask the original error with a failing QUIT
            raise
        s.quit()

    def _getPrefix(self, prefix_len):
        """Return a free prefix of ``prefix_len`` bits (as a '0'/'1' string).

        Prefixes are relative to the network prefix.  The longest free entry
        not longer than ``prefix_len`` is split in halves until it has the
        requested length: each split keeps the '1' half free and descends
        into the '0' half.  The returned prefix is NOT marked as used; the
        caller does that by storing a certificate for it.

        Raises ValueError if ``prefix_len`` is out of range, and
        StopIteration (after logging) when no free prefix is left.
        """
        max_len = 128 - len(self.network)
        if not 0 < prefix_len <= max_len:
            raise ValueError('prefix length %r out of range ]0, %u]'
                             % (prefix_len, max_len))
        try:
            prefix, = next(self.db.execute(
                """SELECT prefix FROM cert WHERE length(prefix) <= ? AND cert is null
                   ORDER BY length(prefix) DESC""", (prefix_len,)))
        except StopIteration:
            logging.error('There are no more free /%s prefix available',
                          prefix_len)
            raise
        while len(prefix) < prefix_len:
            self.db.execute("UPDATE cert SET prefix = ? WHERE prefix = ?",
                            (prefix + '1', prefix))
            prefix += '0'
            self.db.execute("INSERT INTO cert VALUES (?,null,null)", (prefix,))
        if len(prefix) < max_len or '1' in prefix:
            return prefix
        # A full-length all-zero prefix is never handed out: mark it
        # 'reserved' so it is no longer free, and search again.
        self.db.execute("UPDATE cert SET cert = 'reserved' WHERE prefix = ?",
                        (prefix,))
        return self._getPrefix(prefix_len)

    def requestCertificate(self, handler, token, cert_req):
        """Consume ``token`` and return a PEM certificate for ``cert_req``.

        A free prefix of the length bound to the token is allocated; its
        integer value and length are written in the subject CN as
        "<prefix>/<len>", and the certificate is stored in the cert table.
        Any error is printed to stderr and re-raised.
        """
        try:
            req = crypto.load_certificate_request(crypto.FILETYPE_PEM, cert_req)
            # NOTE(review): the connection is in autocommit mode
            # (isolation_level=None), so this block is probably not atomic:
            # a failure after the DELETE does not restore the token. Verify.
            with self.db:
                try:
                    token, email, prefix_len, _ = next(self.db.execute(
                        "SELECT * FROM token WHERE token = ?", (token,)))
                except StopIteration:
                    logging.exception('Bad token (%s) in request', token)
                    raise
                self.db.execute("DELETE FROM token WHERE token = ?", (token,))

                # Get a new prefix
                prefix = self._getPrefix(prefix_len)

                # Create certificate
                cert = crypto.X509()
                # TODO: no serial number is set on issued certificates.
                cert.gmtime_adj_notBefore(0)
                cert.gmtime_adj_notAfter(self.cert_duration)
                cert.set_issuer(self.ca.get_subject())
                subject = req.get_subject()
                subject.CN = "%u/%u" % (int(prefix, 2), prefix_len)
                cert.set_subject(subject)
                cert.set_pubkey(req.get_pubkey())
                # NOTE(review): SHA-1 is deprecated for signatures; consider
                # 'sha256' once all peers are known to accept it.
                cert.sign(self.key, 'sha1')
                cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)

                # Storing the certificate is what takes the prefix out of
                # the free pool (_getPrefix only picks rows with cert null).
                self.db.execute(
                    "UPDATE cert SET email = ?, cert = ? WHERE prefix = ?",
                    (email, cert, prefix))
            return cert
        except Exception:
            traceback.print_exc()
            raise

    def getCa(self, handler):
        """Return the CA certificate, PEM-encoded."""
        return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)

    def getPrivateAddress(self, handler):
        """Return the --private address (may be None when not configured)."""
        return self.config.private

    def getBootstrapPeer(self, handler, client_prefix):
        """Return a peer for ``client_prefix``, encrypted to its certificate.

        The re6stnet node at the --private address is queried over UDP and
        the peer line it answers with is encrypted by ``openssl rsautl``
        with the public key of the client's stored certificate.

        Raises LookupError if the prefix has no certificate or no usable
        peer was returned, and EnvironmentError on timeout, when no private
        address is configured, or when openssl fails.
        """
        try:
            cert, = next(self.db.execute(
                "SELECT cert FROM cert WHERE prefix = ?", (client_prefix,)))
        except StopIteration:
            raise LookupError("Unknown prefix: %s" % (client_prefix,))
        if cert is None or cert == 'reserved':
            # Anything but a PEM certificate may leave openssl below blocked
            # on the pipe, hanging the (single-threaded) server.
            raise LookupError("No certificate for prefix: %s" % (client_prefix,))
        if not self.config.private:
            # self.sock only exists when --private was given (see __init__).
            raise EnvironmentError("No private address configured")
        address = self.config.private, tunnel.PORT
        self.sock.sendto('\2', address)
        peer = None
        # Wait up to 1 second for the first answer, then only drain the
        # datagrams that are already queued (timeout 0).
        while select.select([self.sock], [], [], 1 if peer is None else 0)[0]:
            msg = self.sock.recv(1 << 16)
            if msg[:1] == '\1':  # slicing also copes with empty datagrams
                lines = msg[1:].split('\n')
                # With a '\n'-terminated payload, [-2] is its last line.
                peer = lines[-2] if len(lines) > 1 else ''
        if peer is None:
            raise EnvironmentError("Timeout while querying [%s]:%u" % address)
        fields = peer.split()
        # Presumably the first field is the peer's prefix: never send a
        # client back to itself.
        if not fields or fields[0] == client_prefix:
            raise LookupError("No bootstrap peer found")
        logging.info("Sending bootstrap peer: %s", peer)
        r, w = os.pipe()
        try:
            # The certificate reaches openssl through a pipe opened via
            # /proc/self/fd (Linux procfs); a thread writes it while this
            # thread drives openssl's stdin/stdout.
            threading.Thread(target=os.write, args=(w, cert)).start()
            # close_fds=False (the Python 2 default) so that the child
            # inherits the read end ``r``.
            p = subprocess.Popen(('openssl', 'rsautl', '-encrypt', '-certin',
                                  '-inkey', '/proc/self/fd/%u' % r),
                                 stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                 close_fds=False)
            encrypted = p.communicate(peer)[0]
        finally:
            os.close(r)
            os.close(w)
        if p.returncode:
            raise EnvironmentError("openssl rsautl failed (exit status %s)"
                                   % (p.returncode,))
        return xmlrpclib.Binary(encrypted)
225
if __name__ == '__main__':
    # Constructing the registry runs it: main.__init__ ends in a loop that
    # serves requests forever.
    main()