Do not fail if OpenVPN calls 'disconnect' hook without having called 'connect' hook...
[re6stnet.git] / re6st / tunnel.py
1 import logging, random, socket, subprocess, time
2 from collections import deque
3 from itertools import chain
4 from . import plib, utils
5
# UDP port of the peer-to-peer control socket (bound by TunnelManager and
# used as destination of address requests).
PORT = 326
# Flag bit marking cached entries in /proc/net/ipv6_route; such routes are
# ignored when counting routes.
RTF_CACHE = 1 << 24

# Be careful: the tunnel refresh interval should be long enough to let the
# routes be established.
10
class Connection:
    """Outgoing OpenVPN client tunnel to the peer owning `prefix`.

    `prefix` is a string of binary digits. The OpenVPN client is started
    on `iface` as soon as the object is created. `routes` starts at 0 and
    is updated by TunnelManager when it counts routes.
    """

    def __init__(self, address, write_pipe, timeout, iface, prefix, encrypt,
                 ovpn_args):
        # The prefix is passed to OpenVPN's --tls-remote option in its
        # "<integer>/<length>" form. The --route-up hook reports to
        # `write_pipe`.
        self.process = plib.client(iface, address, encrypt,
            '--tls-remote', '%u/%u' % (int(prefix, 2), len(prefix)),
            '--connect-retry-max', '3', '--tls-exit',
            '--ping-exit', str(timeout),
            '--route-up', '%s %u' % (plib.ovpn_client, write_pipe),
            *ovpn_args)
        self.iface = iface
        self.routes = 0
        self._prefix = prefix

    def refresh(self):
        """Return True while the OpenVPN client process is still running.

        When the process has exited, log its return code and return False.
        """
        returncode = self.process.poll()
        if returncode is not None:
            # Same "<integer>/<length>" prefix format as the other logs.
            logging.info('Connection with %u/%u has failed with return code %s',
                         int(self._prefix, 2), len(self._prefix), returncode)
            return False
        return True
32
33
class TunnelManager(object):
    """Maintain up to `client_count` outgoing OpenVPN tunnels to peers.

    Peers to connect to are chosen from the kernel routing table and from
    the peer database.

    Every `refresh` seconds, routes are counted per tunnel interface. The
    tunnels carrying the fewest routes are then killed to make room for new
    ones.

    A UDP socket bound to PORT is used to query other peers for their
    address, and to answer such queries.
    """

    def __init__(self, write_pipe, peer_db, openvpn_args, timeout,
                 refresh, client_count, iface_list, network, prefix,
                 address, ip_changed, encrypt):
        self._write_pipe = write_pipe
        self._peer_db = peer_db
        # Prefixes we sent an address request to and are waiting for.
        self._connecting = set()
        # prefix -> Connection, for the tunnels we initiated.
        self._connection_dict = {}
        # Set by _countRoutes: None when the registry looks reachable,
        # otherwise the prefixes routed via interfaces of iface_list.
        self._disconnected = None
        self._distant_peers = []
        # tap interface name -> prefix of the peer it is used for.
        self._iface_to_prefix = {}
        self._ovpn_args = openvpn_args
        self._timeout = timeout
        self._refresh_time = refresh
        self._network = network
        self._iface_list = iface_list
        self._prefix = prefix
        self._address = utils.address_str(address)
        self._ip_changed = ip_changed
        self._encrypt = encrypt
        # Prefixes reported by OpenVPN 'client-connect' hooks and not yet
        # by 'client-disconnect' hooks.
        self._served = set()

        self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        # See also http://stackoverflow.com/questions/597225/
        # about binding and anycast.
        self.sock.bind(('::', PORT))

        self.next_refresh = time.time()
        self._next_tunnel_refresh = time.time()

        self._client_count = client_count
        self._refresh_count = 1
        self.new_iface_list = deque('re6stnet' + str(i)
            for i in xrange(1, self._client_count + 1))
        self._free_iface_list = []

    def _tuntap(self, iface=None):
        """Create a new tap interface, or delete `iface` if given.

        A deleted name is pushed back to the front of new_iface_list, so
        it is the next one to be created. Return the interface name.
        """
        if iface:
            self.new_iface_list.appendleft(iface)
            action = 'del'
        else:
            iface = self.new_iface_list.popleft()
            action = 'add'
        args = 'ip', 'tuntap', action, 'dev', iface, 'mode', 'tap'
        logging.debug('%r', args)
        subprocess.call(args)
        return iface

    def delInterfaces(self):
        """Delete every tap interface created so far, free or in use."""
        # Extends the free list in place; the loop below empties it.
        iface_list = self._free_iface_list
        iface_list += self._iface_to_prefix
        self._iface_to_prefix.clear()
        while iface_list:
            self._tuntap(iface_list.pop())

    def getFreeInterface(self, prefix):
        """Return a tap interface for a tunnel to `prefix`.

        A free interface is reused when possible; otherwise a new one is
        created.
        """
        try:
            iface = self._free_iface_list.pop()
        except IndexError:
            iface = self._tuntap()
        self._iface_to_prefix[iface] = prefix
        return iface

    def freeInterface(self, iface):
        """Mark `iface` as unused; it is kept for later reuse."""
        self._free_iface_list.append(iface)
        del self._iface_to_prefix[iface]

    def refresh(self):
        """Perform periodic maintenance of tunnels.

        - Reap tunnels whose OpenVPN process has exited.
        - Once the tunnel refresh time has passed, recount routes and drop
          the least useful tunnels.
        - Try to fill free slots with new tunnels.
        - Schedule the next call 5 seconds later.
        """
        logging.debug('Checking tunnels...')
        self._cleanDeads()
        remove = self._next_tunnel_refresh < time.time()
        if remove:
            self._countRoutes()
            self._removeSomeTunnels()
            self._next_tunnel_refresh = time.time() + self._refresh_time
            self._peer_db.log()
        self._makeNewTunnels(remove)
        if remove and self._free_iface_list:
            self._tuntap(self._free_iface_list.pop())
        self.next_refresh = time.time() + 5

    def _cleanDeads(self):
        """Kill the tunnels whose OpenVPN process has exited."""
        # Snapshot the keys: _kill removes entries from the dict.
        for prefix in list(self._connection_dict):
            if not self._connection_dict[prefix].refresh():
                self._kill(prefix)

    def _removeSomeTunnels(self):
        """Kill the tunnels with the fewest routes.

        At most client_count - refresh_count tunnels are left.
        """
        # Get the candidates to killing
        candidates = sorted(self._connection_dict, key=lambda p:
                            self._connection_dict[p].routes)
        for prefix in candidates[0: max(0, len(self._connection_dict) -
                                 self._client_count + self._refresh_count)]:
            self._kill(prefix)

    def _kill(self, prefix):
        """Stop the tunnel to `prefix` and release its interface."""
        logging.info('Killing the connection with %u/%u...',
                     int(prefix, 2), len(prefix))
        connection = self._connection_dict.pop(prefix)
        self.freeInterface(connection.iface)
        try:
            connection.process.stop()
        except OSError:
            pass # we already polled an exited process
        logging.trace('Connection with %u/%u killed',
                      int(prefix, 2), len(prefix))

    def _makeTunnel(self, prefix, address):
        """Start a tunnel to `prefix` at `address`.

        Nothing is done if a tunnel from or to this peer already exists.
        Return True if a tunnel was started.
        """
        assert len(self._connection_dict) < self._client_count, (prefix, self.__dict__)
        if prefix in self._served or prefix in self._connection_dict:
            return False
        assert prefix != self._prefix, self.__dict__
        logging.info('Establishing a connection with %u/%u',
                     int(prefix, 2), len(prefix))
        iface = self.getFreeInterface(prefix)
        self._connection_dict[prefix] = Connection(address, self._write_pipe,
            self._timeout, iface, prefix, self._encrypt, self._ovpn_args)
        self._peer_db.connecting(prefix, 1)
        return True

    def _makeNewTunnels(self, route_counted):
        """Try to use all free tunnel slots.

        `route_counted` is true when the caller has just called
        _countRoutes, so the routing table is not read again.

        Peers without a known address are queried over the UDP socket. The
        tunnel is then made when their answer arrives (see handlePeerEvent).
        """
        count = self._client_count - len(self._connection_dict)
        if not count:
            return
        assert count >= 0
        # CAVEAT: Forget any peer that didn't reply to our previous address
        #         request, either because latency is too high or some packet
        #         was lost. However, this means that some time should pass
        #         before calling _makeNewTunnels again.
        self._connecting.clear()
        # Same list object as self._distant_peers, so it sees the refresh
        # done by _countRoutes below.
        distant_peers = self._distant_peers
        if len(distant_peers) < count and not route_counted:
            self._countRoutes()
        disconnected = self._disconnected
        if disconnected is not None:
            # We aren't the registry node and we have no tunnel to or from it,
            # so it looks like we are not connected to the network, and our
            # neighbours are in the same situation.
            self._disconnected = None
            disconnected = set(disconnected).union(distant_peers)
            if disconnected:
                # We do have neighbours that are probably also disconnected,
                # so force rebootstrapping.
                peer = self._peer_db.getBootstrapPeer()
                if not peer:
                    # Registry dead ? Assume we're connected after all.
                    disconnected = None
                elif peer[0] not in disconnected:
                    # Got a node that will probably help us rejoining the
                    # network, so connect to it.
                    count -= self._makeTunnel(*peer)
        if disconnected is None:
            # Normal operation. Choose peers to connect to by looking at the
            # routing table.
            while count and distant_peers:
                # Pick a random entry, removing it by swapping with the last.
                i = random.randrange(0, len(distant_peers))
                peer = distant_peers[i]
                distant_peers[i] = distant_peers[-1]
                del distant_peers[-1]
                address = self._peer_db.getAddress(peer)
                if address:
                    count -= self._makeTunnel(peer, address)
                else:
                    ip = utils.ipFromBin(self._network + peer)
                    # TODO: Send at least 1 address. This helps the registry
                    #       node filling its cache when building a new network.
                    try:
                        self.sock.sendto('\2', (ip, PORT))
                    except socket.error as e:
                        logging.info('Failed to query %s (%s)', ip, e)
                    self._connecting.add(peer)
                    count -= 1
        elif count:
            # No route/tunnel to registry, which usually happens when starting
            # up. Select peers from cache for which we have no route.
            for peer, address in self._peer_db.getPeerList():
                if peer not in disconnected and self._makeTunnel(peer, address):
                    count -= 1
                    if not count:
                        break
            else:
                if not (disconnected or self._served or self._connection_dict):
                    # Startup without any good address in the cache.
                    peer = self._peer_db.getBootstrapPeer()
                    if not (peer and self._makeTunnel(*peer)):
                        # Failed to bootstrap! Last chance to connect is to
                        # retry an address that already failed :(
                        for peer in self._peer_db.getPeerList(1):
                            if self._makeTunnel(*peer):
                                break

    def _countRoutes(self):
        """Parse /proc/net/ipv6_route and update the tunnel state.

        - Count, for each tunnel, the routes going through its interface.
        - Collect the prefixes of peers inside our network that we neither
          serve nor have a tunnel to. Those routed via an interface of
          iface_list go to `other`; the rest go to _distant_peers.
        - Set _disconnected according to whether the registry is reachable.
        """
        logging.debug('Starting to count the routes on each interface...')
        del self._distant_peers[:]
        for conn in self._connection_dict.itervalues():
            conn.routes = 0
        a = len(self._network)
        b = a + len(self._prefix)
        other = []
        with open('/proc/net/ipv6_route') as f:
            self._last_routing_table = f.read()
        for line in self._last_routing_table.splitlines():
            line = line.split()
            # Last fields are the route flags (hex) and the interface name.
            iface = line[-1]
            if iface == 'lo' or int(line[-2], 16) & RTF_CACHE:
                continue
            ip = bin(int(line[0], 16))[2:].rjust(128, '0')
            # Only routes inside our network, excluding our own subnet.
            if ip[:a] != self._network or ip[a:b] == self._prefix:
                continue
            prefix_len = int(line[1], 16)
            prefix = ip[a:prefix_len]
            logging.trace('Route on iface %s detected to %s/%u',
                          iface, utils.ipFromBin(ip), prefix_len)
            nexthop = self._iface_to_prefix.get(iface)
            if nexthop:
                self._connection_dict[nexthop].routes += 1
            if prefix in self._served or prefix in self._connection_dict:
                continue
            if iface in self._iface_list:
                other.append(prefix)
            else:
                self._distant_peers.append(prefix)
        is_registry = self._peer_db.registry_ip[a:].startswith
        if is_registry(self._prefix) or any(is_registry(peer)
                for peer in chain(self._distant_peers, other,
                                  self._served, self._connection_dict)):
            self._disconnected = None
            # XXX: When there is no new peer to connect when looking at routes
            #      coming from tunnels, we'd like to consider those discovered
            #      from the LAN. However, we don't want to create tunnels to
            #      nodes of the LAN so do nothing until we find a way to get
            #      some information from Babel.
            #if not self._distant_peers:
            #    self._distant_peers = other
        else:
            self._disconnected = other
        logging.debug("Routes counted: %u distant peers",
                      len(self._distant_peers))
        for c in self._connection_dict.itervalues():
            logging.trace('- %s: %s', c.iface, c.routes)

    def killAll(self):
        """Kill every tunnel we initiated."""
        # Snapshot the keys: _kill removes entries from the dict.
        for prefix in list(self._connection_dict):
            self._kill(prefix)

    def handleTunnelEvent(self, msg):
        """Dispatch a message sent by an OpenVPN hook.

        `msg` is "<event> <arg>...". The event name, with '-' replaced by
        '_', selects the _ovpn_<event> method, which is called with the
        remaining words. Empty or unknown messages are logged and ignored.
        """
        try:
            msg = msg.rstrip()
            args = msg.split()
            # IndexError: empty (or whitespace-only) message.
            m = getattr(self, '_ovpn_' + args.pop(0).replace('-', '_'))
        except (AttributeError, IndexError, ValueError):
            logging.warning("Unknown message received from OpenVPN: %s", msg)
        else:
            logging.debug(msg)
            m(*args)

    def _ovpn_client_connect(self, common_name):
        """A peer connected to our server.

        Our own tunnel to that peer is killed when our prefix is the
        smaller one, so only one tunnel remains between the two peers.
        """
        prefix = utils.binFromSubnet(common_name)
        self._served.add(prefix)
        if prefix in self._connection_dict and self._prefix < prefix:
            self._kill(prefix)
            self._peer_db.connecting(prefix, 0)

    def _ovpn_client_disconnect(self, common_name):
        """A peer disconnected from our server.

        discard() is used because OpenVPN may call this hook without
        having called the 'connect' one.
        """
        prefix = utils.binFromSubnet(common_name)
        self._served.discard(prefix)

    def _ovpn_route_up(self, common_name, ip):
        """One of our tunnels is up.

        `ip` is passed to the ip_changed callback, if any, to update our
        advertised address.
        """
        self._peer_db.connecting(utils.binFromSubnet(common_name), 0)
        if self._ip_changed:
            self._address = utils.address_str(self._ip_changed(ip))

    def handlePeerEvent(self):
        """Handle one datagram received on self.sock.

        The first byte is a code:

        - 1: answer, followed by "<prefix> <address>\\n" lines;
        - 2: address request, answered with code 1;
        - 255: topology request, only honoured from the registry.

        Empty datagrams, and datagrams whose source is outside our network,
        are ignored.
        """
        msg, address = self.sock.recvfrom(1<<16)
        if not (msg and utils.binFromIp(address[0]).startswith(self._network)):
            return
        code = ord(msg[0])
        if code == 1: # answer
            # Only newline-terminated lines are parsed, so a truncated last
            # line is discarded. Malformed lines are skipped.
            for line in msg[1:].split('\n')[:-1]:
                try:
                    prefix, peer_address = line.split()
                    int(prefix, 2) # prefixes are strings of binary digits
                except ValueError:
                    logging.debug('Ignoring malformed answer line from %s: %r',
                                  address[0], line)
                    continue
                if prefix != self._prefix:
                    self._peer_db.addPeer(prefix, peer_address)
                    try:
                        self._connecting.remove(prefix)
                    except KeyError:
                        continue
                    self._makeTunnel(prefix, peer_address)
        elif code == 2: # request
            encode = '%s %s\n'.__mod__
            if self._address:
                msg = [encode((self._prefix, self._address))]
            else: # I don't know my IP yet!
                msg = []
            # Add an extra random peer, mainly for the registry.
            if random.randint(0, self._peer_db.getPeerCount()):
                msg.append(encode(self._peer_db.getPeerList().next()))
            if msg:
                try:
                    self.sock.sendto('\1' + ''.join(msg), address)
                except socket.error as e:
                    logging.info('Failed to reply to %s (%s)', address, e)
        elif code == 255:
            # the registry wants to know the topology for debugging purpose
            if utils.binFromIp(address[0]) == self._peer_db.registry_ip:
                msg = ['\xfe%s%u/%u\n%u\n' % (msg[1:],
                    int(self._prefix, 2), len(self._prefix),
                    len(self._connection_dict))]
                msg.extend('%u/%u\n' % (int(x, 2), len(x))
                           for group in (self._connection_dict, self._served)
                           for x in group)
                try:
                    self.sock.sendto(''.join(msg), address)
                except socket.error as e:
                    # Best effort: this is only debugging information.
                    logging.debug('Failed to send topology to %s (%s)',
                                  address, e)