PROJECT_MOVED -> https://lab.nexedi.com/nexedi/re6stnet
[re6stnet.git] / re6st / utils.py
1 import argparse, errno, hashlib, logging, os, select as _select
2 import shlex, signal, socket, sqlite3, struct, subprocess
3 import sys, textwrap, threading, time, traceback
4
# Size in bytes of a SHA-1 digest. It is read from the hash object rather
# than by hashing a dummy value: hashlib.sha1('') only works where '' is a
# byte string, whereas digest_size needs no input at all.
HMAC_LEN = hashlib.sha1().digest_size
6
class ReexecException(Exception):
    """Exception type declared by this module; it adds nothing to Exception.

    NOTE(review): no code in this part of the file raises or catches it.
    Presumably it tells the caller that the process must re-execute
    itself -- confirm against the call sites.
    """
9
# BBB: Python < 2.7
# Probe whether CalledProcessError accepts a third (output) argument. When
# the probe raises TypeError, install a constructor that stores all three
# values so the rest of the code can always pass and read ``output``.
try:
    subprocess.CalledProcessError(0, '', '')
except TypeError:
    def __init__(self, returncode, cmd, output=None):
        self.returncode, self.cmd, self.output = returncode, cmd, output
    subprocess.CalledProcessError.__init__ = __init__
18
# Logging thresholds ordered from least to most verbose; setupLog() indexes
# this tuple with ``log_level - 1``. The last entry, 5, is the custom level
# that setupLog() registers under the name 'TRACE'.
logging_levels = (
    logging.WARNING,
    logging.INFO,
    logging.DEBUG,
    5,
)
20
class FileHandler(logging.FileHandler):
    """File handler whose log file can be reopened on request.

    A reopen is requested with async_reopen() and carried out by release(),
    i.e. at a moment when the handler lock is held, so that it never
    happens in the middle of another operation that holds the lock.
    """

    # Pending-reopen flag: set by async_reopen(), cleared by release().
    _reopen = False

    def release(self):
        """Release the handler lock, performing a pending reopen first."""
        try:
            if self._reopen:
                # Clear the flag before close(): close() itself goes through
                # acquire()/release() and must not trigger a second reopen.
                self._reopen = False
                self.close()
                # NOTE(review): the stream returned by _open() is not stored
                # on the handler. Presumably this only (re)creates the file
                # and the base class reopens the stream on the next emit --
                # confirm.
                self._open()
        finally:
            self.lock.release()
        # In the rare case _reopen is set just before the lock was released
        if self._reopen and self.lock.acquire(False):
            self.release()

    def async_reopen(self, *_):
        """Request a reopen of the log file; extra arguments are ignored.

        Accepting arbitrary positional arguments lets this method be
        installed directly as a signal handler. If the lock is free, the
        reopen is done right away through release(); otherwise it is left
        to whoever currently holds the lock.
        """
        self._reopen = True
        if self.lock.acquire(False):
            self.release()
41
def setupLog(log_level, filename=None, **kw):
    """Configure the root logger and the SIGUSR1 handler.

    log_level -- 0 (or any false value) disables logging entirely;
                 otherwise ``logging_levels[log_level-1]`` becomes the
                 threshold of the root logger.
    filename  -- when given together with a true log_level, records go to
                 this file through a FileHandler and SIGUSR1 asks that
                 handler to reopen the file. Otherwise records go to a
                 StreamHandler and SIGUSR1 is ignored.
    kw        -- accepted and ignored.

    Also registers level 5 under the name 'TRACE' and adds a
    ``logging.trace`` shortcut that logs at that level.
    """
    if log_level and filename:
        log_dir = os.path.dirname(filename)
        # A bare file name has an empty directory part, and os.makedirs('')
        # raises; in that case there is simply nothing to create.
        if log_dir:
            makedirs(log_dir)
        handler = FileHandler(filename)
        sig = handler.async_reopen
    else:
        handler = logging.StreamHandler()
        sig = signal.SIG_IGN
    handler.setFormatter(logging.Formatter(
        '%(asctime)s %(levelname)-9s %(message)s', '%d-%m-%Y %H:%M:%S'))
    root = logging.getLogger()
    root.addHandler(handler)
    signal.signal(signal.SIGUSR1, sig)
    if log_level:
        root.setLevel(logging_levels[log_level-1])
    else:
        logging.disable(logging.CRITICAL)
    logging.addLevelName(5, 'TRACE')
    logging.trace = lambda *args, **kw: logging.log(5, *args, **kw)
61
def log_exception():
    """Log the exception currently being handled, at ERROR level.

    The message starts with the last line of the formatted traceback (the
    exception summary) and continues with all the preceding lines.
    """
    lines = traceback.format_exception(*sys.exc_info())
    summary = lines.pop()
    logging.error('%s%s', summary, ''.join(lines))
65
66
class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
    """Help formatter that keeps explicit line breaks of help texts and
    mentions the default of an option only when that default is true."""

    def _get_help_string(self, action):
        """Return the help of `action`, extended by the base class only
        when the action has a true default value."""
        if not action.default:
            return action.help
        return super(HelpFormatter, self)._get_help_string(action)

    def _split_lines(self, text, width):
        """Preserves new lines in option descriptions"""
        return [wrapped
                for paragraph in text.splitlines()
                for wrapped in textwrap.wrap(paragraph, width)]

    def _fill_text(self, text, width, indent):
        """Preserves new lines in other descriptions"""
        return '\n'.join(
            textwrap.fill(paragraph, width=width, initial_indent=indent,
                          subsequent_indent=indent)
            for paragraph in text.splitlines())
84
class ArgParser(argparse.ArgumentParser):
    """Argument parser whose options can also be written in option files.

    convert_arg_line_to_args() defines the syntax of such files; argparse
    only calls it when the parser is created with fromfile_prefix_chars,
    which this class leaves to the caller (through **kw).
    """

    class _HelpFormatter(HelpFormatter):

        def _format_actions_usage(self, actions, groups):
            """Prefix the usage with '[@OPTIONS_FILE] ' when the first
            action is an option."""
            usage = HelpFormatter._format_actions_usage(self, actions, groups)
            if actions and actions[0].option_strings:
                usage = '[@OPTIONS_FILE] ' + usage
            return usage

    _ca_help = ("Certificate authority (CA) file in .pem format."
                " Serial number defines the prefix of the network.")

    def convert_arg_line_to_args(self, arg_line):
        """Yield the command-line arguments equivalent to one file line.

        A line with nothing but whitespace before an optional '#' yields
        nothing. A line starting with '@' is yielded unchanged. Any other
        line is split with shlex: its first word is an option name given
        without dashes, and the remaining words are passed through as is.
        """
        if not arg_line.split('#', 1)[0].rstrip():
            return
        if arg_line.startswith('@'):
            yield arg_line
            return
        words = shlex.split(arg_line)
        option = '--' + words.pop(0)
        # Fall back to the single-dash spelling when the parser does not
        # know the double-dash one.
        if option not in self._option_string_actions:
            option = option[1:]
        yield option
        for word in words:
            yield word

    def __init__(self, **kw):
        # NOTE(review): leading whitespace inside the epilog literal was not
        # visible in the reviewed copy of this file; compare the two example
        # lines with the canonical source before merging.
        super(ArgParser, self).__init__(
            formatter_class=self._HelpFormatter,
            epilog="""Options can be read from a file. For example:
$ cat OPTIONS_FILE
ca /etc/re6stnet/ca.crt""",
            **kw)
114
115
class exit(object):
    """Lock-protected, deferred process exit.

    The object is a context manager around a lock. Once `status` has been
    set (by kill_main() or by a handler installed with signal()), the next
    release() of the lock calls sys.exit(status). Code running inside a
    ``with`` block is therefore not interrupted by the exit; the exit
    happens when the block is left.
    """

    # Pending exit status; None while no exit has been requested.
    status = None

    def __init__(self):
        lock = threading.Lock()
        unlock = lock.release
        self.acquire = lock.acquire

        def release():
            try:
                if self.status is not None:
                    # From now on release() is the plain lock release, so
                    # sys.exit() is called only once.
                    self.release = unlock
                    sys.exit(self.status)
            finally:
                unlock()
        self.release = release

    def __enter__(self):
        self.acquire()

    def __exit__(self, t, v, tb):
        self.release()

    def kill_main(self, status):
        """Record `status` and send SIGTERM to the current process.

        NOTE(review): presumably SIGTERM has been given a handler through
        signal(), so that the main thread performs the exit -- confirm.
        """
        self.status = status
        os.kill(os.getpid(), signal.SIGTERM)

    def signal(self, status, *sigs):
        """Install, for every signal in `sigs`, a handler requesting an exit
        with `status` (an already pending status is kept).

        If the lock is free when a signal arrives, the exit happens
        immediately; otherwise it is left to the release() of the holder.
        """
        def handler(*args):
            if self.status is not None:
                return
            self.status = status
            if self.acquire(0):
                self.release()
        for sig in sigs:
            signal.signal(sig, handler)

exit = exit()
153
154
class Popen(subprocess.Popen):
    """subprocess.Popen that tolerates ENOMEM at creation and can be stopped
    with an escalating terminate/kill."""

    def __init__(self, *args, **kw):
        """Start the process like subprocess.Popen does.

        An OSError with errno ENOMEM is swallowed and the object is marked
        as already finished (returncode -1); any other error is re-raised.
        """
        try:
            super(Popen, self).__init__(*args, **kw)
        except OSError as e:
            if e.errno != errno.ENOMEM:
                raise
            self.returncode = -1

    def stop(self):
        """Terminate the process if it is still running and wait for it.

        The process first receives terminate(); a timer calls kill() if it
        has not exited within 5 seconds. Returns the value of wait(), or
        None when there was no running process to stop.
        """
        if self.pid and self.returncode is None:
            self.terminate()
            killer = threading.Timer(5, self.kill)
            killer.start()
            # PY3: use waitid(WNOWAIT) and call self.poll() after t.cancel()
            r = self.wait()
            killer.cancel()
            return r
174
175
def select(R, W, T):
    """Wait for I/O or for the next deadline, then run the due callbacks.

    R -- mapping from objects to wait on for reading to callbacks
    W -- mapping from objects to wait on for writing to callbacks
    T -- collection of (deadline, callback) pairs, where deadline is a
         time.time() value; when empty, the wait has no timeout

    After the wait, the callbacks of ready readers and then of ready
    writers are called, followed by every callback of T whose deadline has
    passed. If the wait is interrupted by a signal (EINTR), the function
    returns without calling anything.
    """
    # Only the deadlines are compared: min(T) would fall back to comparing
    # the callbacks themselves whenever two deadlines are equal.
    if T:
        timeout = max(0, min(entry[0] for entry in T) - time.time())
    else:
        timeout = None
    try:
        readable, writable, _ = _select.select(R, W, (), timeout)
    except _select.error as e:
        if e.args[0] != errno.EINTR:
            raise
        return
    for ready in readable:
        R[ready]()
    for ready in writable:
        W[ready]()
    now = time.time()
    for next_refresh, refresh in T:
        if next_refresh <= now:
            refresh()
192
def makedirs(*args):
    """Call os.makedirs(*args), ignoring the error raised when the path
    already exists (EEXIST). Any other OSError is propagated."""
    try:
        os.makedirs(*args)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
199
def binFromIp(ip):
    """Return the 128-character '0'/'1' string for a textual IPv6 address."""
    packed = socket.inet_pton(socket.AF_INET6, ip)
    return binFromRawIp(packed)
202
def binFromRawIp(ip):
    """Return the 128-character '0'/'1' string for a 16-byte packed address.

    The address is read as two big-endian unsigned 64-bit integers, each
    rendered in binary and zero-padded on the left to 64 digits.
    """
    return ''.join(format(half, '064b') for half in struct.unpack('>QQ', ip))
206
207
def ipFromBin(ip, suffix=''):
    """Return the textual IPv6 address for a string of '0'/'1' digits.

    ip     -- the leading bits of the address
    suffix -- when `ip` has fewer than 128 digits, the missing ones are
              taken from `suffix`, zero-padded on the left

    Exits the program with an error message if `ip` has more than 128
    digits.
    """
    missing = 128 - len(ip)
    if missing < 0:
        sys.exit("Prefix exceeds 128 bits")
    if missing:
        ip += suffix.rjust(missing, '0')
    high, low = int(ip[:64], 2), int(ip[64:], 2)
    return socket.inet_ntop(socket.AF_INET6, struct.pack('>QQ', high, low))
216
def dump_address(address):
    """Serialize an iterable of addresses, each one a sequence of strings:
    fields are joined with ',' and addresses with ';'."""
    return ';'.join(','.join(fields) for fields in address)
219
def parse_address(address_list):
    """Yield the addresses found in a ';'-separated string.

    Every address must consist of exactly three ','-separated fields (ip,
    port, proto) and the port must be parsable as an integer. Valid
    addresses are yielded as lists of three strings; invalid ones are
    skipped with a warning.
    """
    for address in address_list.split(';'):
        try:
            a = ip, port, proto = address.split(',')
            int(port)
        except ValueError as e:
            logging.warning("Failed to parse node address %r (%s)",
                            address, e)
        else:
            # Outside of the try block, so that only parsing errors are
            # handled above.
            yield a
229
def binFromSubnet(subnet):
    """Convert a 'VALUE/LENGTH' string to a string of '0'/'1' digits.

    VALUE is a decimal integer; its binary representation is zero-padded
    on the left to LENGTH digits (and is longer if VALUE needs more).
    """
    value, length = subnet.split('/')
    digits = bin(int(value))[2:]
    return digits.rjust(int(length), '0')
233
def newHmacSecret():
    """Build the secret generator; the module-level name is rebound to the
    returned function right after this definition.

    The returned function packs three big-endian integers of 64, 64 and 32
    bits into HMAC_LEN bytes. The first one is `x` when given, otherwise
    random; the other two are always random.
    """
    # Secrets must not be predictable: draw the bits from the operating
    # system's random source rather than from the default Mersenne Twister
    # generator of the random module.
    from random import SystemRandom
    g = SystemRandom().getrandbits
    pack = struct.Struct(">QQI").pack
    assert len(pack(0, 0, 0)) == HMAC_LEN
    return lambda x=None: pack(g(64) if x is None else x, g(64), g(32))
newHmacSecret = newHmacSecret()
240
def sqliteCreateTable(db, name, *columns):
    """Create table `name` unless it already exists with the same schema.

    db      -- object with a sqlite3-style execute() method
    name    -- table name, inserted as is into the CREATE statement
    columns -- column definitions, one string per column

    Returns True if the table has been created and None if it already
    existed with exactly the same CREATE statement. Raises
    sqlite3.OperationalError if it exists with a different one.
    """
    # NOTE(review): the generated statement is compared character for
    # character with the one stored by SQLite, so the whitespace of the
    # '\n ' separator matters; confirm it against the canonical source.
    column_defs = ','.join('\n ' + column for column in columns)
    sql = "CREATE TABLE %s (%s)" % (name, column_defs)
    row = db.execute(
        "SELECT sql FROM sqlite_master WHERE type='table' and name=?",
        (name,)).fetchone()
    if row is None:
        db.execute(sql)
        return True
    if row[0] != sql:
        raise sqlite3.OperationalError(
            "table %r already exists with unexpected schema" % name)