Traffic is now taken into account when choosing which tunnel to delete
[re6stnet.git] / registry.py
1 #!/usr/bin/env python
2 import argparse, math, random, select, smtplib, sqlite3, string, socket, time, traceback, errno
3 from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler
4 from email.mime.text import MIMEText
5 from OpenSSL import crypto
6 import utils
7
8
# Fix for SimpleXMLRPCServer (through BaseHTTPServer) to avoid doing a reverse
# DNS lookup on each request: it was causing a 10 s delay on each request
# whenever no reverse DNS entry was available for the client IP.
12 import BaseHTTPServer
13
14
def not_insane_address_string(self):
    """Describe the request's client for log lines without resolving it.

    Meant to replace ``BaseHTTPRequestHandler.address_string``, which used
    ``socket.getfqdn(host)`` and could delay every request by about 10 s when
    the client IP has no reverse DNS entry.  Only the first two items of
    ``client_address`` (host, port) are considered, so IPv6 4-tuples work.
    """
    peer_host, _peer_port = self.client_address[:2]
    return '%s (reverse DNS disabled)' % peer_host
18
# Install the DNS-free address_string on the base HTTP handler class, so every
# request handler derived from it (including the XML-RPC one) stops doing a
# reverse DNS lookup when logging a request.
BaseHTTPServer.BaseHTTPRequestHandler.address_string = not_insane_address_string
20
21
# To generate the server CA certificate, whose serial number encodes the VPN
# network prefix behind a leading 1 bit (here 2001:db8:42::/48):
# openssl req -nodes -new -x509 -key ca.key -set_serial 0x120010db80042 -days 365 -out ca.crt
24
# Socket option level and name used to make the IPv6 listening socket
# IPv6-only, so it does not also claim IPv4 traffic on the port that the
# separate IPv4 server binds.  The values are platform-specific (the
# previously hard-coded 26 is the Linux value of IPV6_V6ONLY), so take them
# from the socket module and only fall back to the Linux numbers.
IPV6_V6ONLY = getattr(socket, 'IPV6_V6ONLY', 26)
SOL_IPV6 = getattr(socket, 'IPPROTO_IPV6', 41)
27
28
class RequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler that hands itself to the called method.

    Registered methods therefore receive this handler as their first
    argument, ahead of the parameters sent by the client.
    """

    def _dispatch(self, method, params):
        call_args = (self,) + params
        return self.server._dispatch(method, call_args)
33
34
class SimpleXMLRPCServer4(SimpleXMLRPCServer):
    """XML-RPC server on an IPv4 socket; base class of the IPv6 variant."""

    # Ask the socket server to set SO_REUSEADDR before binding, so the
    # port can be bound again right after a restart.
    allow_reuse_address = True
38
39
class SimpleXMLRPCServer6(SimpleXMLRPCServer4):
    """XML-RPC server on an IPv6-only socket.

    The socket is restricted to IPv6 before binding; IPv4 clients are handled
    by a separate SimpleXMLRPCServer4 listening on the same port.
    """

    address_family = socket.AF_INET6

    def server_bind(self):
        listener = self.socket
        # Must be set before bind() to take effect.
        listener.setsockopt(SOL_IPV6, IPV6_V6ONLY, 1)
        SimpleXMLRPCServer4.server_bind(self)
47
48
49 class main(object):
50
51 def __init__(self):
52 self.cert_duration = 365 * 86400
53 self.time_out = 86400
54 self.refresh_interval = 600
55 self.last_refresh = time.time()
56
57 # Command line parsing
58 parser = argparse.ArgumentParser(
59 description='Peer discovery http server for vifibnet')
60 _ = parser.add_argument
61 _('port', type=int, help='Port of the host server')
62 _('--db', required=True,
63 help='Path to database file')
64 _('--ca', required=True,
65 help='Path to ca.crt file')
66 _('--key', required=True,
67 help='Path to certificate key')
68 _('--mailhost', required=True,
69 help='SMTP server mail host')
70 self.config = parser.parse_args()
71
72 # Database initializing
73 self.db = sqlite3.connect(self.config.db, isolation_level=None)
74 self.db.execute("""CREATE TABLE IF NOT EXISTS peers (
75 prefix text primary key not null,
76 address text not null,
77 date integer default (strftime('%s','now')))""")
78 self.db.execute("CREATE INDEX IF NOT EXISTS peers_ping ON peers(date)")
79 self.db.execute("""CREATE TABLE IF NOT EXISTS tokens (
80 token text primary key not null,
81 email text not null,
82 prefix_len integer not null,
83 date integer not null)""")
84 try:
85 self.db.execute("""CREATE TABLE vpn (
86 prefix text primary key not null,
87 email text,
88 cert text)""")
89 except sqlite3.OperationalError, e:
90 if e.args[0] != 'table vpn already exists':
91 raise RuntimeError
92 else:
93 self.db.execute("INSERT INTO vpn VALUES ('',null,null)")
94
95 # Loading certificates
96 with open(self.config.ca) as f:
97 self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
98 with open(self.config.key) as f:
99 self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
100 # Get vpn network prefix
101 self.network = bin(self.ca.get_serial_number())[3:]
102 print "Network prefix : %s/%u" % (self.network, len(self.network))
103
104 # Starting server
105 server4 = SimpleXMLRPCServer4(('0.0.0.0', self.config.port), requestHandler=RequestHandler, allow_none=True)
106 server4.register_instance(self)
107 server6 = SimpleXMLRPCServer6(('::', self.config.port), requestHandler=RequestHandler, allow_none=True)
108 server6.register_instance(self)
109
110 # Main loop
111 while True:
112 try:
113 r, w, e = select.select([server4, server6], [], [])
114 except (OSError, select.error) as e:
115 if e.args[0] != errno.EINTR:
116 raise
117 else:
118 for r in r:
119 r._handle_request_noblock()
120
121 def requestToken(self, handler, email):
122 while True:
123 # Generating token
124 token = ''.join(random.sample(string.ascii_lowercase, 8))
125 # Updating database
126 try:
127 self.db.execute("INSERT INTO tokens VALUES (?,?,?,?)", (token, email, 16, int(time.time())))
128 break
129 except sqlite3.IntegrityError:
130 pass
131
132 # Creating and sending email
133 s = smtplib.SMTP(self.config.mailhost)
134 me = 'postmaster@vifibnet.com'
135 msg = MIMEText('Hello world !\nYour token : %s' % (token,))
136 msg['Subject'] = '[Vifibnet] Token Request'
137 msg['From'] = me
138 msg['To'] = email
139 s.sendmail(me, email, msg.as_string())
140 s.quit()
141
142 def _getPrefix(self, prefix_len):
143 assert 0 < prefix_len <= 128 - len(self.network)
144 for prefix, in self.db.execute("""SELECT prefix FROM vpn WHERE length(prefix) <= ? AND cert is null
145 ORDER BY length(prefix) DESC""", (prefix_len,)):
146 while len(prefix) < prefix_len:
147 self.db.execute("UPDATE vpn SET prefix = ? WHERE prefix = ?", (prefix + '1', prefix))
148 prefix += '0'
149 self.db.execute("INSERT INTO vpn VALUES (?,null,null)", (prefix,))
150 return prefix
151 raise RuntimeError # TODO: raise better exception
152
153 def requestCertificate(self, handler, token, cert_req):
154 try:
155 req = crypto.load_certificate_request(crypto.FILETYPE_PEM, cert_req)
156 with self.db:
157 try:
158 token, email, prefix_len, _ = self.db.execute("SELECT * FROM tokens WHERE token = ?", (token,)).next()
159 except StopIteration:
160 # TODO: return nice error message
161 raise
162 self.db.execute("DELETE FROM tokens WHERE token = ?", (token,))
163
164 # Get a new prefix
165 prefix = self._getPrefix(prefix_len)
166
167 # Create certificate
168 cert = crypto.X509()
169 #cert.set_serial_number(serial)
170 cert.gmtime_adj_notBefore(0)
171 cert.gmtime_adj_notAfter(self.cert_duration)
172 cert.set_issuer(self.ca.get_subject())
173 subject = req.get_subject()
174 subject.CN = "%u/%u" % (int(prefix, 2), prefix_len)
175 cert.set_subject(subject)
176 cert.set_pubkey(req.get_pubkey())
177 cert.sign(self.key, 'sha1')
178 cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
179
180 # Insert certificate into db
181 self.db.execute("UPDATE vpn SET email = ?, cert = ? WHERE prefix = ?", (email, cert, prefix))
182
183 return cert
184 except:
185 traceback.print_exc()
186 raise
187
188 def getCa(self, handler):
189 return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)
190
191 def getBootstrapPeer(self, handler):
192 # TODO: Insert a flag column for bootstrap ready servers in peers
193 # ( servers which shouldn't go down or change ip and port as opposed to servers owned by particulars )
194 # that way, we also ascertain that the server sent is not the new node....
195 prefix, address = self.db.execute("SELECT prefix, address FROM peers ORDER BY random() LIMIT 1").next()
196 print "Sending bootstrap peer (%s, %s)" % (prefix, str(address))
197 return prefix, address
198
199 def declare(self, handler, address):
200 print "declaring new node"
201 client_address, address = address
202 #client_address, _ = handler.client_address
203 client_ip = utils.binFromIp(client_address)
204 if client_ip.startswith(self.network):
205 prefix = client_ip[len(self.network):]
206 prefix, = self.db.execute("SELECT prefix FROM vpn WHERE prefix <= ? ORDER BY prefix DESC LIMIT 1", (prefix,)).next()
207 self.db.execute("INSERT OR REPLACE INTO peers (prefix, address) VALUES (?,?)", (prefix, address))
208 return True
209 else:
210 # TODO: use log + DO NOT PRINT BINARY IP
211 print "Unauthorized connection from %s which does not start with %s" % (client_ip, self.network)
212 return False
213
214 def getPeerList(self, handler, n, client_address):
215 assert 0 < n < 1000
216 client_ip = utils.binFromIp(client_address)
217 if client_ip.startswith(self.network):
218 if time.time() > self.last_refresh + self.refresh_interval:
219 print "refreshing peers for dead ones"
220 self.db.execute("DELETE FROM peers WHERE ( date + ? ) <= CAST (strftime('%s', 'now') AS INTEGER)", (self.time_out,))
221 self.last_refesh = time.time()
222 print "sending peers"
223 return self.db.execute("SELECT prefix, address FROM peers ORDER BY random() LIMIT ?", (n,)).fetchall()
224 else:
225 # TODO: use log + DO NOT PRINT BINARY IP
226 print "Unauthorized connection from %s which does not start with %s" % (client_ip, self.network)
227 raise RuntimeError
228
if __name__ == "__main__":
    # Building the registry parses the command line and then serves requests
    # forever; the instance is not used afterwards.
    main()