Egg packaging
[re6stnet.git] / re6st-registry
1 #!/usr/bin/env python
2 import argparse, random, select, smtplib, sqlite3, string, socket
3 import subprocess, time, threading, traceback, errno, logging, os, xmlrpclib
4 from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler
5 from email.mime.text import MIMEText
6 from OpenSSL import crypto
7 from re6st  import utils
8
9 # To generate server ca and key with serial for 2001:db8:42::/48
10 # openssl req -nodes -new -x509 -key ca.key -set_serial 0x120010db80042 -days 365 -out ca.crt
11
# Socket option numbers used to make the IPv6 server socket IPv6-only.
# Take them from the socket module when it provides them (the numeric value
# of IPV6_V6ONLY is platform dependent: 26 is the Linux value) and fall back
# to the Linux values otherwise.
IPV6_V6ONLY = getattr(socket, 'IPV6_V6ONLY', 26)
SOL_IPV6 = getattr(socket, 'IPPROTO_IPV6', 41)
14
15
class RequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler that hands itself to the exported methods.

    Every call is forwarded to the server's dispatcher with this handler
    object inserted in front of the client-supplied parameters, so each
    exported method receives the handler as its first positional argument.
    """

    def _dispatch(self, method, params):
        forwarded = (self,) + params
        dispatcher = self.server
        return dispatcher._dispatch(method, forwarded)
20
21
class SimpleXMLRPCServer4(SimpleXMLRPCServer):
    """XML-RPC server that sets SO_REUSEADDR on its socket before binding.

    This lets the registry be restarted on the same port right away, even
    if old connections to that port are still lingering.
    """

    allow_reuse_address = True
25
26
class SimpleXMLRPCServer6(SimpleXMLRPCServer4):
    """IPv6 flavour of SimpleXMLRPCServer4 whose socket is IPv6-only.

    IPV6_V6ONLY is set before binding so the '::' socket does not also take
    IPv4-mapped addresses; an IPv4 server can then listen on the same port
    next to it.
    """

    address_family = socket.AF_INET6

    def server_bind(self):
        sock = self.socket
        sock.setsockopt(SOL_IPV6, IPV6_V6ONLY, 1)
        # Explicit base-class call: the Python 2 SocketServer classes are
        # old-style, so super() is not usable here.
        SimpleXMLRPCServer4.server_bind(self)
34
35
class main(object):
    """re6stnet registry server.

    Creating an instance parses the command line, prepares the SQLite
    database, loads the CA certificate and key, and then serves XML-RPC
    requests on IPv4 and IPv6 in an endless loop (the constructor only
    returns if an exception escapes).  The instance is registered with both
    servers, so its public methods form the remote API; because of
    RequestHandler, each of them receives the request handler as its first
    argument (``handler``).
    """

    def __init__(self):
        self.cert_duration = 365 * 86400    # certificate validity, seconds
        self.time_out = 86400               # peers silent this long are purged
        self.refresh_interval = 600         # minimum seconds between purges
        self.last_refresh = time.time()

        utils.setupLog(3)

        # Command line parsing
        parser = argparse.ArgumentParser(
                description='Peer discovery http server for re6stnet')
        _ = parser.add_argument
        _('port', type=int, help='Port of the host server')
        _('--db', required=True,
                help='Path to database file')
        _('--ca', required=True,
                help='Path to ca.crt file')
        _('--key', required=True,
                help='Path to certificate key')
        _('--mailhost', required=True,
                help='SMTP server mail host')
        _('--bootstrap', action="append",
                help='''VPN prefix of the peers to send as bootstrap peer,
                        instead of random ones''')
        _('--private',
                help='VPN IP of the node on which runs the registry')
        self.config = parser.parse_args()

        # Database initializing (isolation_level=None: autocommit mode)
        self.db = sqlite3.connect(self.config.db, isolation_level=None)
        self.db.execute("""CREATE TABLE IF NOT EXISTS peers (
                        prefix text primary key not null,
                        address text not null,
                        date integer default (strftime('%s','now')))""")
        self.db.execute("CREATE INDEX IF NOT EXISTS peers_ping ON peers(date)")
        self.db.execute("""CREATE TABLE IF NOT EXISTS tokens (
                        token text primary key not null,
                        email text not null,
                        prefix_len integer not null,
                        date integer not null)""")
        try:
            self.db.execute("""CREATE TABLE vpn (
                               prefix text primary key not null,
                               email text,
                               cert text)""")
        except sqlite3.OperationalError as e:
            # Anything but "already exists" is a real failure: re-raise it
            # unchanged so its type and message are not lost.
            if e.args[0] != 'table vpn already exists':
                raise
        else:
            # Fresh table: seed it with the empty prefix, i.e. the whole VPN
            # space as one free block for _getPrefix to split.
            self.db.execute("INSERT INTO vpn VALUES ('',null,null)")

        # Loading certificates
        with open(self.config.ca) as f:
            self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        with open(self.config.key) as f:
            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
        # Get vpn network prefix: bin() gives '0b1...'; dropping the '0b'
        # and the leading 1 bit of the CA serial leaves the prefix bits.
        self.network = bin(self.ca.get_serial_number())[3:]
        logging.info("Network prefix : %s/%u" % (self.network, len(self.network)))

        # Starting server
        server4 = SimpleXMLRPCServer4(('0.0.0.0', self.config.port), requestHandler=RequestHandler, allow_none=True)
        server4.register_instance(self)
        server6 = SimpleXMLRPCServer6(('::', self.config.port), requestHandler=RequestHandler, allow_none=True)
        server6.register_instance(self)

        # Main loop: serve whichever server has a pending request.
        servers = [server4, server6]
        while True:
            try:
                readable = select.select(servers, [], [])[0]
            except (OSError, select.error) as e:
                # A signal interrupting select() is harmless: just retry.
                if e.args[0] != errno.EINTR:
                    raise
            else:
                for server in readable:
                    server._handle_request_noblock()

    def requestToken(self, handler, email):
        """Create a registration token and e-mail it to ``email``.

        The token (8 distinct lowercase letters) is stored in the tokens
        table with a prefix length of 16.  It is the only credential
        requestCertificate asks for, so it is drawn from the OS's
        cryptographically secure random source rather than the default,
        predictable PRNG.
        """
        rng = random.SystemRandom()
        while True:
            # Generating token
            token = ''.join(rng.sample(string.ascii_lowercase, 8))
            # Updating database
            try:
                self.db.execute("INSERT INTO tokens VALUES (?,?,?,?)", (token, email, 16, int(time.time())))
                break
            except sqlite3.IntegrityError:
                # Collides with a pending token: draw another one.
                pass

        # Creating and sending email
        me = 'postmaster@re6st.net'
        msg = MIMEText('Hello world !\nYour token : %s' % (token,))  # XXX
        msg['Subject'] = '[re6stnet] Token Request'
        msg['From'] = me
        msg['To'] = email
        s = smtplib.SMTP(self.config.mailhost)
        try:
            s.sendmail(me, email, msg.as_string())
        except Exception:
            # Do not leak the connection; close() (not quit()) so a dead
            # connection cannot replace the original error.
            s.close()
            raise
        s.quit()

    def _getPrefix(self, prefix_len):
        """Return a free VPN sub-prefix of exactly ``prefix_len`` bits.

        Prefixes are stored as bit strings in the vpn table; rows whose cert
        is null are free.  The longest free prefix not longer than
        ``prefix_len`` is split repeatedly: the row P becomes P+'1' (still
        free), and P+'0' is inserted and split further until it has
        ``prefix_len`` bits.  The all-zero prefix of maximum length is never
        handed out: it is marked 'reserved' and another one is picked.  The
        returned prefix is still free in the table; the caller assigns it.

        Raises ValueError if ``prefix_len`` is not in 1..(128 - network
        length), and StopIteration (after logging) if no free prefix is left.
        """
        max_len = 128 - len(self.network)
        if not 0 < prefix_len <= max_len:
            raise ValueError('prefix length %r out of range 1..%u'
                             % (prefix_len, max_len))
        try:
            prefix, = next(self.db.execute(
                """SELECT prefix FROM vpn WHERE length(prefix) <= ? AND cert is null
                   ORDER BY length(prefix) DESC""", (prefix_len,)))
        except StopIteration:
            logging.error('There are no more free /%s prefix available' % (prefix_len,))
            raise
        while len(prefix) < prefix_len:
            self.db.execute("UPDATE vpn SET prefix = ? WHERE prefix = ?", (prefix + '1', prefix))
            prefix += '0'
            self.db.execute("INSERT INTO vpn VALUES (?,null,null)", (prefix,))
        if len(prefix) < max_len or '1' in prefix:
            return prefix
        self.db.execute("UPDATE vpn SET cert = 'reserved' WHERE prefix = ?", (prefix,))
        return self._getPrefix(prefix_len)

    def requestCertificate(self, handler, token, cert_req):
        """Consume ``token`` and return a signed PEM certificate.

        ``cert_req`` is a PEM certificate request.  The token is deleted, a
        prefix of the token's length is allocated with _getPrefix, and the
        issued certificate has CN "<prefix as integer>/<prefix length>".  The
        certificate and the token's e-mail are stored on the prefix row.
        Any error is printed with its traceback and re-raised; an unknown
        token raises StopIteration.
        """
        try:
            req = crypto.load_certificate_request(crypto.FILETYPE_PEM, cert_req)
            # NOTE(review): the connection is in autocommit mode
            # (isolation_level=None), so this 'with' block probably does not
            # make the statements below atomic - confirm.
            with self.db:
                try:
                    token, email, prefix_len, _ = next(self.db.execute(
                        "SELECT * FROM tokens WHERE token = ?", (token,)))
                except StopIteration:
                    logging.exception('Bad token (%s) in request' % (token,))
                    raise
                self.db.execute("DELETE FROM tokens WHERE token = ?", (token,))

                # Get a new prefix
                prefix = self._getPrefix(prefix_len)

                # Create certificate
                cert = crypto.X509()
                # NOTE(review): no serial number is set on issued certificates.
                cert.gmtime_adj_notBefore(0)
                cert.gmtime_adj_notAfter(self.cert_duration)
                cert.set_issuer(self.ca.get_subject())
                subject = req.get_subject()
                subject.CN = "%u/%u" % (int(prefix, 2), prefix_len)
                cert.set_subject(subject)
                cert.set_pubkey(req.get_pubkey())
                # NOTE(review): SHA-1 signatures are deprecated; consider a
                # stronger digest once clients are known to accept it.
                cert.sign(self.key, 'sha1')
                cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)

                # Insert certificate into db
                self.db.execute("UPDATE vpn SET email = ?, cert = ? WHERE prefix = ?", (email, cert, prefix))

            return cert
        except Exception:
            traceback.print_exc()
            raise

    def getCa(self, handler):
        """Return the CA certificate in PEM format."""
        return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)

    def getPrivateAddress(self, handler):
        """Return the registry URL on the VPN, built from --private and port."""
        return 'http://[%s]:%u' % (self.config.private, self.config.port)

    def _randomPeer(self):
        """Return a random (prefix, address) row from the peers table.

        Raises StopIteration when no peer is declared.
        """
        return next(self.db.execute("""SELECT prefix, address
                        FROM peers ORDER BY random() LIMIT 1"""))

    def getBootstrapPeer(self, handler, client_prefix):
        """Return one peer to bootstrap from, encrypted for the client.

        With --bootstrap, one of the given prefixes is chosen at random.
        If it is not a declared peer, a random declared peer is used instead,
        as it also is when --bootstrap is not given.  The string
        "<prefix> <address>" is encrypted with ``openssl rsautl -encrypt``
        using the certificate stored for ``client_prefix`` and returned as
        XML-RPC binary data.
        """
        cert, = next(self.db.execute("SELECT cert FROM vpn WHERE prefix = ?",
                (client_prefix,)))
        # NOTE(review): logging.trace is not a stdlib logging function;
        # presumably utils.setupLog installs it - confirm.
        logging.trace('Getting bootpeer info...')
        if self.config.bootstrap:
            bootpeer = random.choice(self.config.bootstrap)
            try:
                prefix, address = next(self.db.execute("""SELECT prefix, address
                            FROM peers WHERE prefix = ?""", (bootpeer,)))
            except StopIteration:
                logging.info('Bootstrap peer %s unknown, sending random peer'
                        % hex(int(bootpeer, 2))[2:])
                prefix, address = self._randomPeer()
        else:
            prefix, address = self._randomPeer()
        logging.trace('Gotten bootpeer info from db')
        r, w = os.pipe()
        try:
            # openssl reads the certificate from our pipe through
            # /proc/self/fd/<r>; the write is done in a thread so it cannot
            # block before openssl starts reading.  This relies on the child
            # inheriting fd r (true on Python 2, whose Popen keeps fds open).
            threading.Thread(target=os.write, args=(w, cert)).start()
            p = subprocess.Popen(('openssl', 'rsautl', '-encrypt', '-certin', '-inkey', '/proc/self/fd/%u' % r),
                stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            logging.info("Sending bootstrap peer (%s, %s)" % (prefix, address))
            return xmlrpclib.Binary(p.communicate('%s %s' % (prefix, address))[0])
        finally:
            os.close(r)
            os.close(w)

    def _logUnauthorized(self, client_ip):
        """Log a client whose address (bit string) is outside the VPN network."""
        logging.warning("Unauthorized connection from %s which does not start with %s"
                % (utils.ipFromBin(client_ip), utils.ipFromBin(self.network.ljust(128, '0'))))

    def _purgeDeadPeers(self):
        """Delete peers whose last declaration is time_out seconds old or more.

        The purge runs at most once every refresh_interval seconds.
        """
        if time.time() > self.last_refresh + self.refresh_interval:
            logging.info("refreshing peers for dead ones")
            self.db.execute("DELETE FROM peers WHERE ( date + ? ) <= CAST (strftime('%s', 'now') AS INTEGER)", (self.time_out,))
            self.last_refresh = time.time()

    def declare(self, handler, address):
        """Record (or refresh) the public address of a VPN node.

        ``address`` is a pair (client VPN IP, address to advertise).  The
        node's prefix is found from its VPN IP and the peers row is replaced,
        which also resets its date to now.  Returns False, after logging, if
        the IP is outside the VPN network.
        """
        logging.info("declaring new node")
        # NOTE(review): the client's VPN IP is taken from the payload, not
        # from handler.client_address, so it is not checked against the
        # actual connection.
        client_address, address = address
        client_ip = utils.binFromIp(client_address)
        if not client_ip.startswith(self.network):
            self._logUnauthorized(client_ip)
            return False
        host_bits = client_ip[len(self.network):]
        # The vpn rows are disjoint prefixes (see _getPrefix), so the greatest
        # row <= host_bits is the one that host_bits starts with.
        prefix, = next(self.db.execute("SELECT prefix FROM vpn WHERE prefix <= ? ORDER BY prefix DESC LIMIT 1", (host_bits,)))
        self.db.execute("INSERT OR REPLACE INTO peers (prefix, address) VALUES (?,?)", (prefix, address))
        return True

    def getPeerList(self, handler, n, client_address):
        """Return up to ``n`` random (prefix, address) peers.

        ``n`` must be in 1..999.  Stale peers are purged first (at most once
        per refresh_interval).  Raises RuntimeError, after logging, if
        ``client_address`` is outside the VPN network.
        """
        # Explicit check instead of assert: n comes from a remote client and
        # asserts are stripped under -O.
        if not 0 < n < 1000:
            raise ValueError('n must be between 1 and 999, got %r' % (n,))
        client_ip = utils.binFromIp(client_address)
        if not client_ip.startswith(self.network):
            self._logUnauthorized(client_ip)
            raise RuntimeError
        self._purgeDeadPeers()
        logging.info("sending peers")
        return self.db.execute("SELECT prefix, address FROM peers ORDER BY random() LIMIT ?", (n,)).fetchall()
255
if __name__ == '__main__':
    # Constructing the registry is what runs it: main.__init__ serves
    # requests in a loop until an exception escapes.
    main()