Try not to break server connections in multi-gateway mode
[re6stnet.git] / re6st-registry
1 #!/usr/bin/python
2 import errno, logging, mailbox, os, random, select
3 import smtplib, socket, sqlite3, string, subprocess, sys
4 import threading, time, traceback, xmlrpclib
5 from collections import deque
6 from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler
7 from email.mime.text import MIMEText
8 from OpenSSL import crypto
9 from re6st import tunnel, utils
10
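# To generate the CA certificate for 2001:db8:42::/48 from an existing CA key
# (ca.key). The serial is a leading marker bit 1 followed by the network prefix
# bits; the registry strips the marker when it reads the serial back.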
12 # openssl req -nodes -new -x509 -key ca.key -set_serial 0x120010db80042 -days 365 -out ca.crt
13
# Socket option used to make the IPv6 XML-RPC listener IPv6-only, so that it
# can share its port with the separate IPv4 listener. Prefer the platform's
# values from the socket module (IPV6_V6ONLY is 27 on BSD/macOS, for example)
# and fall back to the Linux values that used to be hard-coded here.
IPV6_V6ONLY = getattr(socket, 'IPV6_V6ONLY', 26)
SOL_IPV6 = getattr(socket, 'IPPROTO_IPV6', 41)
16
17
class RequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler that hands itself to the dispatched method.

    Every registry method therefore receives the handler as its first
    argument (after self), which lets it inspect e.g. client_address.
    """

    def address_string(self):
        """Return the bare client IP instead of the base class's host lookup."""
        # Workaround for http://bugs.python.org/issue6085
        host = self.client_address[0]
        return host

    def _dispatch(self, method, params):
        """Log the call, then dispatch it with this handler prepended."""
        logging.debug('%s%r', method, params)
        call_args = (self,) + params
        return self.server._dispatch(method, call_args)
27
class SimpleXMLRPCServer4(SimpleXMLRPCServer):
    """IPv4 XML-RPC server whose listening socket sets SO_REUSEADDR.

    This lets the registry rebind its port immediately after a restart.
    """

    # Read by TCPServer.server_bind() to enable SO_REUSEADDR before bind().
    allow_reuse_address = True
32
class SimpleXMLRPCServer6(SimpleXMLRPCServer4):
    """IPv6-only variant of SimpleXMLRPCServer4.

    The socket is restricted to IPv6 before binding, so it does not also
    claim IPv4 traffic and can share the port with the IPv4 server.
    """

    address_family = socket.AF_INET6

    def server_bind(self):
        listener = self.socket
        # The option only takes effect if set before bind().
        listener.setsockopt(SOL_IPV6, IPV6_V6ONLY, 1)
        SimpleXMLRPCServer4.server_bind(self)
40
41
class main(object):
    """re6stnet registry: bootstraps nodes and delivers certificates.

    Instantiating this class parses the command line, opens (and, on first
    use, initializes) the SQLite database, loads the CA certificate and key,
    then serves XML-RPC requests forever on the configured IPv4 and/or IPv6
    addresses. Public methods are exposed over XML-RPC; RequestHandler passes
    the request handler as their first argument. Names starting with '_'
    are not reachable by XML-RPC clients.
    """

    def __init__(self):
        self.cert_duration = 365 * 86400
        self.time_out = 45000
        self.refresh_interval = 600
        self.last_refresh = time.time()

        # Command line parsing
        parser = utils.ArgParser(fromfile_prefix_chars='@',
            description="re6stnet registry used to bootstrap nodes"
                        " and deliver certificates.")
        _ = parser.add_argument
        _('--port', type=int, default=80,
            help="Port on which the server will listen.")
        _('-4', dest='bind4', default='0.0.0.0',
            help="Bind server to this IPv4.")
        _('-6', dest='bind6', default='::',
            help="Bind server to this IPv6.")
        _('--db', default='/var/lib/re6stnet/registry.db',
            help="Path to SQLite database file. It is automatically initialized"
                 " if the file does not exist.")
        _('--ca', required=True, help=parser._ca_help)
        _('--key', required=True,
                help="CA private key in .pem format.")
        _('--mailhost', required=True,
                help="SMTP host to send confirmation emails. For debugging"
                     " purpose, it can also be an absolute or existing path to"
                     " a mailbox file")
        _('--private',
                help="re6stnet IP of the node on which runs the registry."
                     " Required for normal operation.")
        _('--prefix-length', default=16, type=int,
                help="Default length of allocated prefixes.")
        _('--anonymous-prefix-length', type=int,
                help="Length of allocated anonymous prefixes."
                     " If 0 or unset, registration by email is required")
        _('-l', '--logfile', default='/var/log/re6stnet/registry.log',
                help="Path to logging file.")
        _('-v', '--verbose', default=1, type=int,
                help="Log level. 0 disables logging."
                     " Use SIGUSR1 to reopen log.")
        self.config = parser.parse_args()

        utils.setupLog(self.config.verbose, self.config.logfile)

        if self.config.private:
            # UDP socket used by getBootstrapPeer() and topology() to query
            # re6stnet nodes (starting with the one at config.private).
            self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        else:
            # NOTE(review): self.sock is left undefined here, so
            # getBootstrapPeer()/topology() will fail with AttributeError.
            logging.warning('You have declared no private address'
                    ', either this is the first start, or you should'
                    ' check your configuration')

        # Database initializing
        utils.makedirs(os.path.dirname(self.config.db))
        self.db = sqlite3.connect(self.config.db, isolation_level=None)
        self.db.execute("""CREATE TABLE IF NOT EXISTS token (
                        token text primary key not null,
                        email text not null,
                        prefix_len integer not null,
                        date integer not null)""")
        try:
            self.db.execute("""CREATE TABLE cert (
                               prefix text primary key not null,
                               email text,
                               cert text)""")
        except sqlite3.OperationalError as e:
            if e.args[0] != 'table cert already exists':
                # Keep the original sqlite error instead of a bare
                # RuntimeError that carries no information.
                raise
        else:
            # Fresh table: seed it with the empty prefix (the whole
            # network), which _getPrefix() splits on demand.
            self.db.execute("INSERT INTO cert VALUES ('',null,null)")

        # Loading certificates
        with open(self.config.ca) as f:
            self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        with open(self.config.key) as f:
            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
        # Get vpn network prefix: the CA serial in binary is a marker '1'
        # followed by the network bits; [3:] drops '0b' and the marker.
        self.network = bin(self.ca.get_serial_number())[3:]
        logging.info("Network: %s/%u", utils.ipFromBin(self.network),
                                       len(self.network))
        self._email = self.ca.get_subject().emailAddress

        # Starting server
        server_list = []
        if self.config.bind4:
            server4 = SimpleXMLRPCServer4((self.config.bind4, self.config.port),
                requestHandler=RequestHandler, allow_none=True)
            server4.register_instance(self)
            server_list.append(server4)
        if self.config.bind6:
            server6 = SimpleXMLRPCServer6((self.config.bind6, self.config.port),
                requestHandler=RequestHandler, allow_none=True)
            server6.register_instance(self)
            server_list.append(server6)

        if len(server_list) == 1:
            server_list[0].serve_forever()
        else:
            # Serve both listeners from one thread: wait until one is
            # readable and handle exactly the requests that are pending.
            while True:
                try:
                    ready = select.select(server_list, [], [])[0]
                except select.error as e:
                    # A signal (e.g. SIGUSR1 to reopen the log) must not
                    # stop the server.
                    if e.args[0] != errno.EINTR:
                        raise
                else:
                    for server in ready:
                        server._handle_request_noblock()

    def requestToken(self, handler, email):
        """Create a registration token for `email` and mail it there.

        The token is stored in the `token` table together with the email,
        the default prefix length (--prefix-length) and the current time.
        The mail goes to an mbox file if --mailhost is an absolute or
        existing path, otherwise through the SMTP server --mailhost.
        """
        # Tokens authorize certificate issuance, so draw them from the OS
        # CSPRNG instead of the predictable default Mersenne Twister.
        rng = random.SystemRandom()
        while True:
            # Generating token
            token = ''.join(rng.sample(string.ascii_lowercase, 8))
            args = token, email, self.config.prefix_length, int(time.time())
            # Updating database
            try:
                self.db.execute("INSERT INTO token VALUES (?,?,?,?)", args)
                break
            except sqlite3.IntegrityError:
                # Collision with an existing token: draw another one.
                pass

        # Creating and sending email
        msg = MIMEText('Hello, your token to join re6st network is: %s\n'
                       % token)
        msg['Subject'] = '[re6stnet] Token Request'
        if self._email:
            msg['From'] = self._email
        msg['To'] = email
        if os.path.isabs(self.config.mailhost) or \
           os.path.isfile(self.config.mailhost):
            m = mailbox.mbox(self.config.mailhost)
            try:
                m.add(msg)
            finally:
                m.close()
        else:
            s = smtplib.SMTP(self.config.mailhost)
            try:
                s.sendmail(self._email, email, msg.as_string())
            finally:
                # Always release the connection; if QUIT itself fails (e.g.
                # server already gone), just close the socket so the
                # original error, if any, is not masked.
                try:
                    s.quit()
                except (smtplib.SMTPException, socket.error):
                    s.close()

    def _getPrefix(self, prefix_len):
        """Find a free prefix of exactly `prefix_len` bits in the cert table.

        Free prefixes are rows whose cert is NULL. The longest free prefix
        not longer than `prefix_len` is chosen and split in halves until it
        has the requested length: each split renames the row to its '1'
        half and inserts the '0' half as a new row. The all-zero prefix of
        full length (128 - len(self.network)) is never returned; it is
        marked 'reserved' (presumably because it would be the network
        address itself) and the search starts again.

        The returned row is NOT marked as used; the caller must set its
        cert. Raises ValueError for an out-of-range `prefix_len`, and
        StopIteration (after logging) when no free prefix fits.
        """
        max_len = 128 - len(self.network)
        if not 0 < prefix_len <= max_len:
            raise ValueError('invalid prefix length: %r' % (prefix_len,))
        while True:
            try:
                prefix, = next(self.db.execute(
                    "SELECT prefix FROM cert WHERE length(prefix) <= ?"
                    " AND cert is null ORDER BY length(prefix) DESC",
                    (prefix_len,)))
            except StopIteration:
                logging.error('No more free /%u prefix available', prefix_len)
                raise
            while len(prefix) < prefix_len:
                self.db.execute("UPDATE cert SET prefix = ? WHERE prefix = ?",
                                (prefix + '1', prefix))
                prefix += '0'
                self.db.execute("INSERT INTO cert VALUES (?,null,null)",
                                (prefix,))
            if len(prefix) < max_len or '1' in prefix:
                return prefix
            self.db.execute("UPDATE cert SET cert = 'reserved' WHERE prefix = ?",
                            (prefix,))

    def requestCertificate(self, handler, token, cert_req):
        """Sign the PEM certificate request `cert_req` and return the PEM cert.

        With `token` None, the anonymous prefix length is used, and None is
        returned if anonymous registration is disabled. Otherwise the token
        is looked up and consumed, and None is returned if it is unknown.
        The certificate CN is "<prefix as integer>/<prefix length>". Any
        exception is logged with its traceback and re-raised.
        """
        try:
            req = crypto.load_certificate_request(crypto.FILETYPE_PEM, cert_req)
            # NOTE(review): the connection is in autocommit mode
            # (isolation_level=None), in which sqlite3 opens no implicit
            # transaction, so this block is probably not atomic (a failure
            # after the DELETE keeps the token consumed). Confirm before
            # relying on rollback.
            with self.db:
                if token is None:
                    prefix_len = self.config.anonymous_prefix_length
                    if not prefix_len:
                        return
                    email = None
                else:
                    try:
                        token, email, prefix_len, _ = next(self.db.execute(
                            "SELECT * FROM token WHERE token = ?",
                            (token,)))
                    except StopIteration:
                        return
                    self.db.execute("DELETE FROM token WHERE token = ?", (token,))

                # Get a new prefix
                prefix = self._getPrefix(prefix_len)

                # Create certificate
                cert = crypto.X509()
                cert.set_serial_number(0) # required for libssl < 1.0
                cert.gmtime_adj_notBefore(0)
                cert.gmtime_adj_notAfter(self.cert_duration)
                cert.set_issuer(self.ca.get_subject())
                subject = req.get_subject()
                subject.CN = "%u/%u" % (int(prefix, 2), prefix_len)
                cert.set_subject(subject)
                cert.set_pubkey(req.get_pubkey())
                # NOTE(review): SHA-1 signatures are deprecated; switching to
                # sha256 needs checking against deployed clients.
                cert.sign(self.key, 'sha1')
                cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)

                # Insert certificate into db
                self.db.execute("UPDATE cert SET email = ?, cert = ? WHERE prefix = ?", (email, cert, prefix))

            return cert
        except Exception:
            f = traceback.format_exception(*sys.exc_info())
            logging.error('%s%s', f.pop(), ''.join(f))
            raise

    def getCa(self, handler):
        """Return the CA certificate in PEM format."""
        return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)

    def getPrivateAddress(self, handler):
        """Return the --private address (may be None if not configured)."""
        return self.config.private

    def getBootstrapPeer(self, handler, client_prefix):
        """Return a bootstrap peer, encrypted for the client `client_prefix`.

        Sends '\\2' to the re6stnet node at (config.private, tunnel.PORT)
        and reads replies starting with '\\1'; the peer is the next-to-last
        line of such a reply. The peer string is encrypted with the public
        key of the client's stored certificate (openssl rsautl) and returned
        as xmlrpclib.Binary.

        Raises EnvironmentError if no reply arrives within 1 second, and
        LookupError if no peer is known or the only one is the client.
        """
        cert, = next(self.db.execute("SELECT cert FROM cert WHERE prefix = ?",
                (client_prefix,)))
        address = self.config.private, tunnel.PORT
        self.sock.sendto('\2', address)
        peer = None
        # The boolean doubles as the select timeout: wait up to 1 s while no
        # answer has been parsed, then only drain already-queued datagrams.
        while select.select([self.sock], [], [], peer is None)[0]:
            msg = self.sock.recv(1<<16)
            # Slice instead of index so an empty datagram is ignored rather
            # than raising IndexError.
            if msg[:1] == '\1':
                try:
                    peer = msg[1:].split('\n')[-2]
                except IndexError:
                    peer = ''
        if peer is None:
            raise EnvironmentError("Timeout while querying [%s]:%u" % address)
        if not peer or peer.split()[0] == client_prefix:
            raise LookupError("No bootstrap peer found")
        logging.info("Sending bootstrap peer: %s", peer)
        # Feed the certificate to openssl through a pipe it opens as
        # /proc/self/fd/<r>; a thread writes it so the write cannot block
        # this one.
        r, w = os.pipe()
        try:
            threading.Thread(target=os.write, args=(w, cert)).start()
            # close_fds=False: the child must inherit fd r for the
            # /proc/self/fd path to refer to the pipe.
            p = subprocess.Popen(('openssl', 'rsautl', '-encrypt', '-certin', '-inkey', '/proc/self/fd/%u' % r),
                stdin=subprocess.PIPE, stdout=subprocess.PIPE, close_fds=False)
            return xmlrpclib.Binary(p.communicate(peer)[0])
        finally:
            os.close(r)
            os.close(w)

    @staticmethod
    def _isLocal(host):
        """Return True if `host` is the IPv4 or IPv6 loopback address."""
        # '::1' is the IPv6 loopback; '::' (unspecified) is never a source.
        return host in ('127.0.0.1', '::1')

    def topology(self, handler):
        """Crawl the network from the registry node and return its graph.

        Only answers local (loopback) clients; returns None otherwise.
        Starting from the registry's own prefix, sends '\\xff<cookie>\\n' to
        each known node at tunnel.PORT and parses '\\xfe' replies of the form
        cookie, node, count, neighbours... (one per line). Returns a dict
        mapping each discovered node ("<prefix int>/<len>") to the first
        `count` neighbours it reported, or None if it never answered. The
        crawl stops after 1 second with nothing to read or send.
        """
        if not self._isLocal(handler.client_address[0]):
            return None
        is_registry = utils.binFromIp(self.config.private
            )[len(self.network):].startswith
        peers = deque('%u/%u' % (int(x, 2), len(x))
            for x, in self.db.execute("SELECT prefix FROM cert")
            if is_registry(x))
        assert len(peers) == 1
        # Random 32-bit cookie to match replies to this crawl; formatted
        # with '%x' so it never gets a Python 2 long 'L' suffix.
        cookie = '%x' % random.getrandbits(32)
        graph = dict.fromkeys(peers)
        while True:
            r, w, _ = select.select([self.sock],
                [self.sock] if peers else [], [], 1)
            if r:
                answer = self.sock.recv(1<<16)
                if answer[:1] == '\xfe':
                    answer = answer[1:].split('\n')[:-1]
                    if len(answer) >= 3 and answer[0] == cookie:
                        x = answer[3:]
                        assert answer[1] not in x, (answer, graph)
                        graph[answer[1]] = x[:int(answer[2])]
                        # Queue neighbours never seen before.
                        x = set(x).difference(graph)
                        peers += x
                        graph.update(dict.fromkeys(x))
            if w:
                x = utils.binFromSubnet(peers.popleft())
                x = utils.ipFromBin(self.network + x)
                try:
                    self.sock.sendto('\xff%s\n' % cookie, (x, tunnel.PORT))
                except socket.error:
                    # Unreachable node: best effort, keep crawling.
                    pass
            elif not r:
                break
        return graph
313
if __name__ == "__main__":
    # Constructing the registry object parses arguments and then serves
    # XML-RPC requests forever; the instance itself is not needed.
    main()