Allow the registry to know the topology, for debugging purposes
[re6stnet.git] / re6st-registry
1 #!/usr/bin/env python
2 import random, select, smtplib, sqlite3, string, socket
3 import subprocess, time, threading, traceback, errno, logging, os, xmlrpclib
4 from collections import deque
5 from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler
6 from email.mime.text import MIMEText
7 from OpenSSL import crypto
8 from re6st import tunnel, utils
9
10 # To generate server ca and key with serial for 2001:db8:42::/48
11 # openssl req -nodes -new -x509 -key ca.key -set_serial 0x120010db80042 -days 365 -out ca.crt
12
# Socket-option numbers used to make the IPv6 listener IPv6-only.  Prefer the
# values exported by the socket module, since they are platform-specific
# (IPV6_V6ONLY is 26 on Linux but 27 on BSD/macOS); fall back to the Linux
# numbers when the module does not export them.
IPV6_V6ONLY = getattr(socket, 'IPV6_V6ONLY', 26)
SOL_IPV6 = getattr(socket, 'IPPROTO_IPV6', 41)
15
16
class RequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler that hands itself to the called method.

    The registry's exported methods all take the handler as their first
    argument, which lets them inspect e.g. ``client_address``.
    """

    def address_string(self):
        # Report the bare client IP instead of resolving it.
        # Workaround for http://bugs.python.org/issue6085
        host = self.client_address[0]
        return host

    def _dispatch(self, method, params):
        logging.debug('%s%r', method, params)
        # Prepend the handler itself to the positional arguments.
        args = (self,) + params
        return self.server._dispatch(method, args)
26
class SimpleXMLRPCServer4(SimpleXMLRPCServer):
    """IPv4 XML-RPC server that may rebind its port right after a restart.

    ``allow_reuse_address`` asks SocketServer to set SO_REUSEADDR on the
    listening socket before binding.
    """
    allow_reuse_address = True
30
31
class SimpleXMLRPCServer6(SimpleXMLRPCServer4):
    """IPv6-only counterpart of SimpleXMLRPCServer4.

    IPV6_V6ONLY is enabled before binding so that listening on '::' does
    not also claim IPv4 traffic; the IPv4 listener handles that on the
    same port.
    """

    address_family = socket.AF_INET6

    def server_bind(self):
        sock = self.socket
        sock.setsockopt(SOL_IPV6, IPV6_V6ONLY, 1)
        SimpleXMLRPCServer4.server_bind(self)
39
40
class main(object):
    """re6stnet registry: XML-RPC server handing out tokens and certificates.

    Instantiating the class does all the work: it parses the command line,
    opens the SQLite database, loads the CA certificate and key, then serves
    XML-RPC requests over IPv4 and IPv6 forever.

    The instance is registered with ``register_instance``, so every public
    method is remotely callable; ``RequestHandler._dispatch`` passes the
    request handler as the first argument (``handler``).  Internal helpers
    are prefixed with ``_`` so XML-RPC clients cannot call them.
    """

    def __init__(self):
        self.cert_duration = 365 * 86400  # certificate validity, in seconds
        # The next three attributes are not read by any method of this class.
        self.time_out = 45000
        self.refresh_interval = 600
        self.last_refresh = time.time()

        utils.setupLog(3)

        # Command line parsing
        parser = utils.ArgParser(fromfile_prefix_chars='@',
                description='Peer discovery http server for re6stnet')
        _ = parser.add_argument
        _('--port', type=int, default=80, help='Port of the host server')
        _('--db', required=True,
                help='Path to database file')
        _('--ca', required=True,
                help='Path to ca.crt file')
        _('--key', required=True,
                help='Path to certificate key')
        _('--mailhost', required=True,
                help='SMTP server mail host')
        _('--private',
                help='VPN IP of the node on which runs the registry')
        self.config = parser.parse_args()

        if self.config.private:
            # UDP socket used by getBootstrapPeer() and topology() to query
            # re6stnet nodes; without --private it does not exist and those
            # two methods fail.
            self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        else:
            logging.warning('You have declared no private address'
                    ', either this is the first start, or you should'
                    ' check your configuration')

        self.db = self._openDatabase(self.config.db)

        # Loading certificates
        with open(self.config.ca) as f:
            self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        with open(self.config.key) as f:
            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
        # The CA serial number encodes the VPN network prefix as a leading
        # '1' marker bit followed by the prefix bits: strip '0b1'.
        self.network = bin(self.ca.get_serial_number())[3:]
        logging.info("Network prefix : %s/%u", self.network, len(self.network))

        # Starting server: one listener per address family, same port.
        server4 = SimpleXMLRPCServer4(('0.0.0.0', self.config.port),
            requestHandler=RequestHandler, allow_none=True)
        server4.register_instance(self)
        server6 = SimpleXMLRPCServer6(('::', self.config.port),
            requestHandler=RequestHandler, allow_none=True)
        server6.register_instance(self)

        self._serve(server4, server6)

    def _openDatabase(self, path):
        """Open the SQLite database at *path*, creating the schema if needed.

        The connection is in autocommit mode (isolation_level=None).  When
        the ``cert`` table is created, it is seeded with the empty prefix ''
        (the whole address space, free).

        Raises sqlite3.OperationalError for any failure other than the
        ``cert`` table already existing.
        """
        db = sqlite3.connect(path, isolation_level=None)
        db.execute("""CREATE TABLE IF NOT EXISTS token (
                        token text primary key not null,
                        email text not null,
                        prefix_len integer not null,
                        date integer not null)""")
        try:
            db.execute("""CREATE TABLE cert (
                               prefix text primary key not null,
                               email text,
                               cert text)""")
        except sqlite3.OperationalError as e:
            if e.args[0] != 'table cert already exists':
                raise  # keep the original error instead of a bare RuntimeError
        else:
            db.execute("INSERT INTO cert VALUES ('',null,null)")
        return db

    def _serve(self, *servers):
        """Handle requests on *servers* forever, in the calling thread."""
        while True:
            try:
                ready = select.select(list(servers), [], [])[0]
            except (OSError, select.error) as e:
                # A signal interrupted select(): simply wait again.
                if e.args[0] != errno.EINTR:
                    raise
            else:
                for server in ready:
                    server._handle_request_noblock()

    def requestToken(self, handler, email):
        """Create a one-time registration token and e-mail it to *email*.

        The token is 8 distinct random lowercase letters, stored with a
        prefix length of 16 and the current time.

        Raises ValueError if *email* is missing.
        """
        if email is None:
            # Otherwise the NOT NULL violation below would be mistaken for a
            # token collision and the loop would never terminate.
            raise ValueError('an e-mail address is required')
        # Tokens gate certificate issuance: use the OS CSPRNG.
        rng = random.SystemRandom()
        while True:
            # Generating token
            token = ''.join(rng.sample(string.ascii_lowercase, 8))
            # Updating database
            try:
                self.db.execute("INSERT INTO token VALUES (?,?,?,?)",
                                (token, email, 16, int(time.time())))
                break
            except sqlite3.IntegrityError:
                pass  # token already taken: draw another one

        # Creating and sending email
        me = 'postmaster@re6st.net'
        msg = MIMEText('Hello world !\nYour token : %s' % (token,))  # XXX
        msg['Subject'] = '[re6stnet] Token Request'
        msg['From'] = me
        msg['To'] = email
        s = smtplib.SMTP(self.config.mailhost)
        try:
            s.sendmail(me, email, msg.as_string())
        finally:
            s.quit()

    def _getPrefix(self, prefix_len):
        """Find a free prefix of exactly *prefix_len* bits and return it.

        The ``cert`` table holds a binary tree of prefixes (bit strings
        relative to ``self.network``); rows whose ``cert`` is null are free.
        The longest free prefix not longer than *prefix_len* is split until
        it has the requested length: at each step the '1' half stays free in
        the table and the '0' half is kept.  The returned row is left with a
        null ``cert``; the caller is expected to fill it.  An all-zero prefix
        of the maximum length is marked 'reserved' instead and the search is
        retried.

        Raises StopIteration when no suitable free prefix remains.
        """
        max_len = 128 - len(self.network)
        assert 0 < prefix_len <= max_len
        try:
            prefix, = next(self.db.execute(
                """SELECT prefix FROM cert WHERE length(prefix) <= ? AND cert is null
                   ORDER BY length(prefix) DESC""", (prefix_len,)))
        except StopIteration:
            logging.error('There are no more free /%s prefix available', prefix_len)
            raise
        while len(prefix) < prefix_len:
            self.db.execute("UPDATE cert SET prefix = ? WHERE prefix = ?",
                            (prefix + '1', prefix))
            prefix += '0'
            self.db.execute("INSERT INTO cert VALUES (?,null,null)", (prefix,))
        if len(prefix) < max_len or '1' in prefix:
            return prefix
        # NOTE(review): presumably reserved because it would map to the
        # network's own all-zero address — confirm.
        self.db.execute("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,))
        return self._getPrefix(prefix_len)

    def requestCertificate(self, handler, token, cert_req):
        """Exchange *token* for a certificate signed by the registry CA.

        *cert_req* is a PEM certificate signing request.  The token is
        consumed, a prefix of the length stored with the token is
        allocated, and a PEM certificate is returned whose subject is the
        request's with CN set to "<prefix as integer>/<prefix length>".
        Any error is printed to stderr and re-raised.
        """
        try:
            req = crypto.load_certificate_request(crypto.FILETYPE_PEM, cert_req)
            # NOTE(review): the connection uses isolation_level=None
            # (autocommit), so this ``with`` presumably does not make the
            # statements below atomic (a failure after the DELETE still
            # consumes the token) — confirm.
            with self.db:
                try:
                    token, email, prefix_len, _ = next(self.db.execute(
                        "SELECT * FROM token WHERE token = ?", (token,)))
                except StopIteration:
                    logging.exception('Bad token (%s) in request', token)
                    raise
                self.db.execute("DELETE FROM token WHERE token = ?", (token,))

                # Get a new prefix
                prefix = self._getPrefix(prefix_len)

                # Create certificate
                cert = crypto.X509()
                # NOTE(review): no serial number is set (a call was left
                # commented out here), so all certificates presumably share
                # the default serial — confirm.
                cert.gmtime_adj_notBefore(0)
                cert.gmtime_adj_notAfter(self.cert_duration)
                cert.set_issuer(self.ca.get_subject())
                subject = req.get_subject()
                subject.CN = "%u/%u" % (int(prefix, 2), prefix_len)
                cert.set_subject(subject)
                cert.set_pubkey(req.get_pubkey())
                # NOTE(review): SHA-1 signatures are deprecated; consider
                # 'sha256' once every peer accepts it.
                cert.sign(self.key, 'sha1')
                cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)

                # Insert certificate into db
                self.db.execute("UPDATE cert SET email = ?, cert = ? WHERE prefix = ?",
                                (email, cert, prefix))

            return cert
        except:
            # Make failures visible on the registry side too, then let the
            # XML-RPC layer report them to the client.
            traceback.print_exc()
            raise

    def getCa(self, handler):
        """Return the registry CA certificate in PEM format."""
        return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)

    def getPrivateAddress(self, handler):
        """Return the --private address given on the command line (or None)."""
        return self.config.private

    def getBootstrapPeer(self, handler, client_prefix):
        """Return a bootstrap peer for *client_prefix*, encrypted for it.

        The registry's own node (``--private``) is asked with a '\\2' UDP
        datagram; the last line of its latest '\\1' answer is the peer.  The
        peer string is encrypted with the public key of the certificate
        stored for *client_prefix* (``openssl rsautl -encrypt -certin``)
        and returned as xmlrpclib.Binary.

        Raises EnvironmentError if the node does not answer within 1 second,
        LookupError if no peer other than the client itself was returned.
        """
        cert, = next(self.db.execute("SELECT cert FROM cert WHERE prefix = ?",
                                     (client_prefix,)))
        address = self.config.private, tunnel.PORT
        self.sock.sendto('\2', address)
        peer = None
        # Wait up to 1s (timeout True) for a first answer, then only drain
        # already-queued answers (timeout False == 0), keeping the latest.
        while select.select([self.sock], [], [], peer is None)[0]:
            msg = self.sock.recv(1<<16)
            if msg.startswith('\1'):
                try:
                    peer = msg[1:].split('\n')[-2]
                except IndexError:
                    peer = ''
        if peer is None:
            raise EnvironmentError("Timeout while querying [%s]:%u" % address)
        if not peer or peer.split()[0] == client_prefix:
            raise LookupError("No bootstrap peer found")
        logging.info("Sending bootstrap peer: %s", peer)
        r, w = os.pipe()
        try:
            # Feed the certificate to openssl through a pipe it opens via
            # /proc/self/fd; the write runs in a thread so it cannot block.
            threading.Thread(target=os.write, args=(w, cert)).start()
            p = subprocess.Popen(('openssl', 'rsautl', '-encrypt', '-certin',
                                  '-inkey', '/proc/self/fd/%u' % r),
                stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            return xmlrpclib.Binary(p.communicate(peer)[0])
        finally:
            os.close(r)
            os.close(w)

    def topology(self, handler):
        """Explore the overlay starting from the registry node (debug aid).

        Only served to clients connecting from a loopback address; other
        callers get None.  Each node, identified by a "<int>/<len>" subnet
        string, is sent a 0xff-tagged UDP query carrying a random cookie.
        A 0xfe-tagged answer with that cookie lists: the answering node's
        subnet, a count, then subnets.  The first <count> subnets become
        that node's value in the graph, and all unseen subnets are queued
        for querying.  Exploration ends once every queued node has been
        asked and no answer arrived for 1 second.

        Returns a dict mapping subnet strings to lists of subnet strings,
        with None for nodes that never answered.
        """
        # '::' is the unspecified address and never a source; '::1' is the
        # IPv6 loopback.
        if handler.client_address[0] not in ('127.0.0.1', '::1'):
            return None
        # Does bit string x prefix the registry's own VPN address (relative
        # to the network prefix)?
        is_registry = utils.binFromIp(self.config.private
            )[len(self.network):].startswith
        peers = deque('%u/%u' % (int(x, 2), len(x))
            for x, in self.db.execute("SELECT prefix FROM cert")
            if is_registry(x))
        assert len(peers) == 1
        cookie = '%x' % random.getrandbits(32)
        graph = dict.fromkeys(peers)
        while True:
            r, w, _ = select.select([self.sock],
                [self.sock] if peers else [], [], 1)
            if r:
                answer = self.sock.recv(1 << 16)
                if answer.startswith('\xfe'):
                    answer = answer[1:].split('\n')[:-1]
                    if len(answer) >= 3 and answer[0] == cookie:
                        x = answer[3:]
                        assert answer[1] not in x, (answer, graph)
                        graph[answer[1]] = x[:int(answer[2])]
                        x = set(x).difference(graph)
                        peers += x
                        graph.update(dict.fromkeys(x))
            if w:
                x = utils.binFromSubnet(peers.popleft())
                x = utils.ipFromBin(self.network + x)
                try:
                    self.sock.sendto('\xff%s\n' % cookie, (x, tunnel.PORT))
                except socket.error:
                    pass  # best effort: unreachable nodes just stay None
            elif not r:
                break
        return graph
262
if __name__ == '__main__':
    # Constructing the registry parses the command line and then serves
    # requests forever.
    main()