# NOTE(review): the four lines removed here were not part of the module --
# they were injected artifacts from a defaced server file listing (a 'PK'
# zip signature plus a hostname/IP banner). The module source starts below.
''' Base netlink socket and marshal =============================== All the netlink providers are derived from the socket class, so they provide normal socket API, including `getsockopt()`, `setsockopt()`, they can be used in poll/select I/O loops etc. asynchronous I/O ---------------- To run async reader thread, one should call `NetlinkSocket.bind(async_cache=True)`. In that case a background thread will be launched. The thread will automatically collect all the messages and store into a userspace buffer. .. note:: There is no need to turn on async I/O, if you don't plan to receive broadcast messages. ENOBUF and async I/O -------------------- When Netlink messages arrive faster than a program reads then from the socket, the messages overflow the socket buffer and one gets ENOBUF on `recv()`:: ... self.recv(bufsize) error: [Errno 105] No buffer space available One way to avoid ENOBUF, is to use async I/O. Then the library not only reads and buffers all the messages, but also re-prioritizes threads. Suppressing the parser activity, the library increases the response delay, but spares CPU to read and enqueue arriving messages as fast, as it is possible. With logging level DEBUG you can notice messages, that the library started to calm down the parser thread:: DEBUG:root:Packet burst: the reader thread priority is increased, beware of delays on netlink calls Counters: delta=25 qsize=25 delay=0.1 This state requires no immediate action, but just some more attention. When the delay between messages on the parser thread exceeds 1 second, DEBUG messages become WARNING ones:: WARNING:root:Packet burst: the reader thread priority is increased, beware of delays on netlink calls Counters: delta=2525 qsize=213536 delay=3 This state means, that almost all the CPU resources are dedicated to the reader thread. It doesn't mean, that the reader thread consumes 100% CPU -- it means, that the CPU is reserved for the case of more intensive bursts. 
The library will return to the normal state only when the broadcast storm will be over, and then the CPU will be 100% loaded with the parser for some time, when it will process all the messages queued so far. when async I/O doesn't help --------------------------- Sometimes, even turning async I/O doesn't fix ENOBUF. Mostly it means, that in this particular case the Python performance is not enough even to read and store the raw data from the socket. There is no workaround for such cases, except of using something *not* Python-based. One can still play around with SO_RCVBUF socket option, but it doesn't help much. So keep it in mind, and if you expect massive broadcast Netlink storms, perform stress testing prior to deploy a solution in the production. classes ------- ''' import collections import errno import logging import os import random import select import struct import threading import time import traceback import warnings from functools import partial from socket import ( MSG_DONTWAIT, MSG_PEEK, MSG_TRUNC, SO_RCVBUF, SO_SNDBUF, SOCK_DGRAM, SOL_SOCKET, ) from pyroute2 import config from pyroute2.common import DEFAULT_RCVBUF, AddrPool from pyroute2.config import AF_NETLINK from pyroute2.netlink import ( NETLINK_ADD_MEMBERSHIP, NETLINK_DROP_MEMBERSHIP, NETLINK_EXT_ACK, NETLINK_GENERIC, NETLINK_GET_STRICT_CHK, NETLINK_LISTEN_ALL_NSID, NLM_F_ACK, NLM_F_ACK_TLVS, NLM_F_DUMP, NLM_F_DUMP_INTR, NLM_F_MULTI, NLM_F_REQUEST, NLMSG_DONE, NLMSG_ERROR, SOL_NETLINK, mtypes, nlmsg, nlmsgerr, ) from pyroute2.netlink.exceptions import ( ChaoticException, NetlinkDecodeError, NetlinkDumpInterrupted, NetlinkError, NetlinkHeaderDecodeError, ) try: from Queue import Queue except ImportError: from queue import Queue log = logging.getLogger(__name__) Stats = collections.namedtuple('Stats', ('qsize', 'delta', 'delay')) NL_BUFSIZE = 32768 class CompileContext: def __init__(self, netlink_socket): self.netlink_socket = netlink_socket self.netlink_socket.compiled = [] def 
__enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() def close(self): self.netlink_socket.compiled = None class Marshal: ''' Generic marshalling class ''' msg_map = {} seq_map = None key_offset = None key_format = None key_mask = None debug = False default_message_class = nlmsg error_type = NLMSG_ERROR def __init__(self): self.lock = threading.Lock() self.msg_map = self.msg_map.copy() self.seq_map = {} self.defragmentation = {} def parse_one_message( self, key, flags, sequence_number, data, offset, length ): msg = None error = None msg_class = self.msg_map.get(key, self.default_message_class) # ignore length for a while # get the message if (key == self.error_type) or ( key == NLMSG_DONE and flags & NLM_F_ACK_TLVS ): msg = nlmsgerr(data, offset=offset) else: msg = msg_class(data, offset=offset) try: msg.decode() except NetlinkHeaderDecodeError as e: msg = nlmsg() msg['header']['error'] = e except NetlinkDecodeError as e: msg['header']['error'] = e if isinstance(msg, nlmsgerr) and msg['error'] != 0: error = NetlinkError( abs(msg['error']), msg.get_attr('NLMSGERR_ATTR_MSG') ) enc_type = struct.unpack_from('H', data, offset + 24)[0] enc_class = self.msg_map.get(enc_type, nlmsg) enc = enc_class(data, offset=offset + 20) enc.decode() msg['header']['errmsg'] = enc msg['header']['error'] = error return msg def get_parser(self, key, flags, sequence_number): return self.seq_map.get( sequence_number, partial(self.parse_one_message, key, flags, sequence_number), ) def parse(self, data, seq=None, callback=None, skip_alien_seq=False): ''' Parse string data. 
At this moment all transport, except of the native Netlink is deprecated in this library, so we should not support any defragmentation on that level ''' offset = 0 # there must be at least one header in the buffer, # 'IHHII' == 16 bytes while offset <= len(data) - 16: # pick type and length (length, key, flags, sequence_number) = struct.unpack_from( 'IHHI', data, offset ) if skip_alien_seq and sequence_number != seq: continue if not 0 < length <= len(data): break # support custom parser keys # see also: pyroute2.netlink.diag.MarshalDiag if self.key_format is not None: (key,) = struct.unpack_from( self.key_format, data, offset + self.key_offset ) if self.key_mask is not None: key &= self.key_mask parser = self.get_parser(key, flags, sequence_number) msg = parser(data, offset, length) offset += length if msg is None: continue if callable(callback) and seq == sequence_number: try: if callback(msg): continue except Exception: pass mtype = msg['header'].get('type', None) if mtype in (1, 2, 3, 4) and 'event' not in msg: msg['event'] = mtypes.get(mtype, 'none') self.fix_message(msg) yield msg def fix_message(self, msg): pass # 8<----------------------------------------------------------- # Singleton, containing possible modifiers to the NetlinkSocket # bind() call. # # Normally, you can open only one netlink connection for one # process, but there is a hack. Current PID_MAX_LIMIT is 2^22, # so we can use the rest to modify the pid field. 
# # See also libnl library, lib/socket.c:generate_local_port() sockets = AddrPool(minaddr=0x0, maxaddr=0x3FF, reverse=True) # 8<----------------------------------------------------------- class LockProxy: def __init__(self, factory, key): self.factory = factory self.refcount = 0 self.key = key self.internal = threading.Lock() self.lock = factory.klass() def acquire(self, *argv, **kwarg): with self.internal: self.refcount += 1 return self.lock.acquire() def release(self): with self.internal: self.refcount -= 1 if (self.refcount == 0) and (self.key != 0): try: del self.factory.locks[self.key] except KeyError: pass return self.lock.release() def __enter__(self): self.acquire() def __exit__(self, exc_type, exc_value, traceback): self.release() class LockFactory: def __init__(self, klass=threading.RLock): self.klass = klass self.locks = {0: LockProxy(self, 0)} def __enter__(self): self.locks[0].acquire() def __exit__(self, exc_type, exc_value, traceback): self.locks[0].release() def __getitem__(self, key): if key is None: key = 0 if key not in self.locks: self.locks[key] = LockProxy(self, key) return self.locks[key] def __delitem__(self, key): del self.locks[key] class EngineBase: def __init__(self, socket): self.socket = socket self.get_timeout = 30 self.get_timeout_exception = None self.change_master = threading.Event() self.read_lock = threading.Lock() self.qsize = 0 @property def marshal(self): return self.socket.marshal @property def backlog(self): return self.socket.backlog @property def backlog_lock(self): return self.socket.backlog_lock @property def error_deque(self): return self.socket.error_deque @property def lock(self): return self.socket.lock @property def buffer_queue(self): return self.socket.buffer_queue @property def epid(self): return self.socket.epid @property def target(self): return self.socket.target @property def callbacks(self): return self.socket.callbacks class EngineThreadSafe(EngineBase): ''' Thread-safe engine for netlink sockets. 
It buffers all incoming messages regardless sequence numbers, and returns only messages with requested numbers. This is done using synchronization primitives in a quite complicated manner. ''' def put( self, msg, msg_type, msg_flags=NLM_F_REQUEST, addr=(0, 0), msg_seq=0, msg_pid=None, ): ''' Construct a message from a dictionary and send it to the socket. Parameters: - msg -- the message in the dictionary format - msg_type -- the message type - msg_flags -- the message flags to use in the request - addr -- `sendto()` addr, default `(0, 0)` - msg_seq -- sequence number to use - msg_pid -- pid to use, if `None` -- use os.getpid() Example:: s = IPRSocket() s.bind() s.put({'index': 1}, RTM_GETLINK) s.get() s.close() Please notice, that the return value of `s.get()` can be not the result of `s.put()`, but any broadcast message. To fix that, use `msg_seq` -- the response must contain the same `msg['header']['sequence_number']` value. ''' if msg_seq != 0: self.lock[msg_seq].acquire() try: if msg_seq not in self.backlog: self.backlog[msg_seq] = [] if not isinstance(msg, nlmsg): msg_class = self.marshal.msg_map[msg_type] msg = msg_class(msg) if msg_pid is None: msg_pid = self.epid or os.getpid() msg['header']['type'] = msg_type msg['header']['flags'] = msg_flags msg['header']['sequence_number'] = msg_seq msg['header']['pid'] = msg_pid self.socket.sendto_gate(msg, addr) except: raise finally: if msg_seq != 0: self.lock[msg_seq].release() def get( self, bufsize=DEFAULT_RCVBUF, msg_seq=0, terminate=None, callback=None, noraise=False, ): ''' Get parsed messages list. If `msg_seq` is given, return only messages with that `msg['header']['sequence_number']`, saving all other messages into `self.backlog`. The routine is thread-safe. 
The `bufsize` parameter can be: - -1: bufsize will be calculated from the first 4 bytes of the network data - 0: bufsize will be calculated from SO_RCVBUF sockopt - int >= 0: just a bufsize If `noraise` is true, error messages will be treated as any other message. ''' ctime = time.time() with self.lock[msg_seq]: if bufsize == -1: # get bufsize from the network data bufsize = struct.unpack("I", self.recv(4, MSG_PEEK))[0] elif bufsize == 0: # get bufsize from SO_RCVBUF bufsize = self.getsockopt(SOL_SOCKET, SO_RCVBUF) // 2 tmsg = None enough = False backlog_acquired = False try: while not enough: # 8<----------------------------------------------------------- # # This stage changes the backlog, so use mutex to # prevent side changes self.backlog_lock.acquire() backlog_acquired = True ## # Stage 1. BEGIN # # 8<----------------------------------------------------------- # # Check backlog and return already collected # messages. # if msg_seq == -1 and any(self.backlog.values()): for seq, backlog in self.backlog.items(): if backlog: for msg in backlog: yield msg self.backlog[seq] = [] enough = True break elif msg_seq == 0 and self.backlog[0]: # Zero queue. # # Load the backlog, if there is valid # content in it for msg in self.backlog[0]: yield msg self.backlog[0] = [] # And just exit break elif msg_seq > 0 and len(self.backlog.get(msg_seq, [])): # Any other msg_seq. # # Collect messages up to the terminator. # Terminator conditions: # * NLMSG_ERROR != 0 # * NLMSG_DONE # * terminate() function (if defined) # * not NLM_F_MULTI # # Please note, that if terminator not occured, # more `recv()` rounds CAN be required. 
for msg in tuple(self.backlog[msg_seq]): # Drop the message from the backlog, if any self.backlog[msg_seq].remove(msg) # If there is an error, raise exception if ( msg['header']['error'] is not None and not noraise ): # reschedule all the remaining messages, # including errors and acks, into a # separate deque self.error_deque.extend(self.backlog[msg_seq]) # flush the backlog for this msg_seq del self.backlog[msg_seq] # The loop is done raise msg['header']['error'] # If it is the terminator message, say "enough" # and requeue all the rest into Zero queue if terminate is not None: tmsg = terminate(msg) if isinstance(tmsg, nlmsg): yield msg if (msg['header']['type'] == NLMSG_DONE) or tmsg: # The loop is done enough = True # If it is just a normal message, append it to # the response if not enough: # finish the loop on single messages if not msg['header']['flags'] & NLM_F_MULTI: enough = True yield msg # Enough is enough, requeue the rest and delete # our backlog if enough: self.backlog[0].extend(self.backlog[msg_seq]) del self.backlog[msg_seq] break # Next iteration self.backlog_lock.release() backlog_acquired = False else: # Stage 1. END # # 8<------------------------------------------------------- # # Stage 2. BEGIN # # 8<------------------------------------------------------- # # Receive the data from the socket and put the messages # into the backlog # self.backlog_lock.release() backlog_acquired = False ## # # Control the timeout. We should not be within the # function more than TIMEOUT seconds. All the locks # MUST be released here. 
# if (msg_seq != 0) and ( time.time() - ctime > self.get_timeout ): # requeue already received for that msg_seq self.backlog[0].extend(self.backlog[msg_seq]) del self.backlog[msg_seq] # throw an exception if self.get_timeout_exception: raise self.get_timeout_exception() else: return # if self.read_lock.acquire(False): try: self.change_master.clear() # If the socket is free to read from, occupy # it and wait for the data # # This is a time consuming process, so all the # locks, except the read lock must be released data = self.socket.recv(bufsize) # Parse data msgs = tuple( self.socket.marshal.parse( data, msg_seq, callback ) ) # Reset ctime -- timeout should be measured # for every turn separately ctime = time.time() # current = self.buffer_queue.qsize() delta = current - self.qsize delay = 0 if delta > 10: delay = min( 3, max(0.01, float(current) / 60000) ) message = ( "Packet burst: " "delta=%s qsize=%s delay=%s" % (delta, current, delay) ) if delay < 1: log.debug(message) else: log.warning(message) time.sleep(delay) self.qsize = current # We've got the data, lock the backlog again with self.backlog_lock: for msg in msgs: msg['header']['target'] = self.target msg['header']['stats'] = Stats( current, delta, delay ) seq = msg['header']['sequence_number'] if seq not in self.backlog: if ( msg['header']['type'] == NLMSG_ERROR ): # Drop orphaned NLMSG_ERROR # messages continue seq = 0 # 8<----------------------------------- # Callbacks section for cr in self.callbacks: try: if cr[0](msg): cr[1](msg, *cr[2]) except: # FIXME # # Usually such code formatting # means that the method should # be refactored to avoid such # indentation. # # Plz do something with it. 
# lw = log.warning lw("Callback fail: %s" % (cr)) lw(traceback.format_exc()) # 8<----------------------------------- self.backlog[seq].append(msg) # Now wake up other threads self.change_master.set() finally: # Finally, release the read lock: all data # processed self.read_lock.release() else: # If the socket is occupied and there is still no # data for us, wait for the next master change or # for a timeout self.change_master.wait(1) # 8<------------------------------------------------------- # # Stage 2. END # # 8<------------------------------------------------------- finally: if backlog_acquired: self.backlog_lock.release() class EngineThreadUnsafe(EngineBase): ''' Thread unsafe nlsocket base class. Does not implement any locks on message processing. Discards any message if the sequence number does not match. ''' def put( self, msg, msg_type, msg_flags=NLM_F_REQUEST, addr=(0, 0), msg_seq=0, msg_pid=None, ): if not isinstance(msg, nlmsg): msg_class = self.marshal.msg_map[msg_type] msg = msg_class(msg) if msg_pid is None: msg_pid = self.epid or os.getpid() msg['header']['type'] = msg_type msg['header']['flags'] = msg_flags msg['header']['sequence_number'] = msg_seq msg['header']['pid'] = msg_pid self.sendto_gate(msg, addr) def get( self, bufsize=DEFAULT_RCVBUF, msg_seq=0, terminate=None, callback=None, noraise=False, ): if bufsize == -1: # get bufsize from the network data bufsize = struct.unpack("I", self.recv(4, MSG_PEEK))[0] elif bufsize == 0: # get bufsize from SO_RCVBUF bufsize = self.getsockopt(SOL_SOCKET, SO_RCVBUF) // 2 enough = False while not enough: data = self.recv(bufsize) *messages, last = tuple( self.marshal.parse(data, msg_seq, callback) ) for msg in messages: msg['header']['target'] = self.target msg['header']['stats'] = Stats(0, 0, 0) yield msg if last['header']['type'] == NLMSG_DONE: break if ( (msg_seq == 0) or (not last['header']['flags'] & NLM_F_MULTI) or (callable(terminate) and terminate(last)) ): enough = True yield last class 
NetlinkSocketBase: ''' Generic netlink socket. ''' input_from_buffer_queue = False def __init__( self, family=NETLINK_GENERIC, port=None, pid=None, fileno=None, sndbuf=1048576, rcvbuf=1048576, all_ns=False, async_qsize=None, nlm_generator=None, target='localhost', ext_ack=False, strict_check=False, groups=0, nlm_echo=False, ): # 8<----------------------------------------- self.config = { 'family': family, 'port': port, 'pid': pid, 'fileno': fileno, 'sndbuf': sndbuf, 'rcvbuf': rcvbuf, 'all_ns': all_ns, 'async_qsize': async_qsize, 'target': target, 'nlm_generator': nlm_generator, 'ext_ack': ext_ack, 'strict_check': strict_check, 'groups': groups, 'nlm_echo': nlm_echo, } # 8<----------------------------------------- self.addr_pool = AddrPool(minaddr=0x000000FF, maxaddr=0x0000FFFF) self.epid = None self.port = 0 self.fixed = True self.family = family self._fileno = fileno self._sndbuf = sndbuf self._rcvbuf = rcvbuf self._use_peek = True self.backlog = {0: []} self.error_deque = collections.deque(maxlen=1000) self.callbacks = [] # [(predicate, callback, args), ...] 
self.buffer_thread = None self.closed = False self.compiled = None self.uname = config.uname self.target = target self.groups = groups self.capabilities = { 'create_bridge': config.kernel > [3, 2, 0], 'create_bond': config.kernel > [3, 2, 0], 'create_dummy': True, 'provide_master': config.kernel[0] > 2, } self.backlog_lock = threading.Lock() self.sys_lock = threading.RLock() self.lock = LockFactory() self._sock = None self._ctrl_read, self._ctrl_write = os.pipe() if async_qsize is None: async_qsize = config.async_qsize self.async_qsize = async_qsize if nlm_generator is None: nlm_generator = config.nlm_generator self.nlm_generator = nlm_generator self.buffer_queue = Queue(maxsize=async_qsize) self.log = [] self.all_ns = all_ns self.ext_ack = ext_ack self.strict_check = strict_check if pid is None: self.pid = os.getpid() & 0x3FFFFF self.port = port self.fixed = self.port is not None elif pid == 0: self.pid = os.getpid() else: self.pid = pid # 8<----------------------------------------- self.marshal = Marshal() # 8<----------------------------------------- if not nlm_generator: def nlm_request(*argv, **kwarg): return tuple(self._genlm_request(*argv, **kwarg)) def get(*argv, **kwarg): return tuple(self._genlm_get(*argv, **kwarg)) self._genlm_request = self.nlm_request self._genlm_get = self.get self.nlm_request = nlm_request self.get = get def nlm_request_batch(*argv, **kwarg): return tuple(self._genlm_request_batch(*argv, **kwarg)) self._genlm_request_batch = self.nlm_request_batch self.nlm_request_batch = nlm_request_batch # Set defaults self.post_init() self.engine = EngineThreadSafe(self) def post_init(self): pass def clone(self): return type(self)(**self.config) def put( self, msg, msg_type, msg_flags=NLM_F_REQUEST, addr=(0, 0), msg_seq=0, msg_pid=None, ): return self.engine.put( msg, msg_type, msg_flags, addr, msg_seq, msg_pid ) def get( self, bufsize=DEFAULT_RCVBUF, msg_seq=0, terminate=None, callback=None, noraise=False, ): return self.engine.get(bufsize, 
msg_seq, terminate, callback, noraise) def close(self, code=errno.ECONNRESET): if code > 0 and self.input_from_buffer_queue: self.buffer_queue.put( struct.pack('IHHQIQQ', 28, 2, 0, 0, code, 0, 0) ) try: os.close(self._ctrl_write) os.close(self._ctrl_read) except OSError: # ignore the case when it is closed already pass def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() def release(self): warnings.warn('deprecated, use close() instead', DeprecationWarning) self.close() def register_callback(self, callback, predicate=lambda x: True, args=None): ''' Register a callback to run on a message arrival. Callback is the function that will be called with the message as the first argument. Predicate is the optional callable object, that returns True or False. Upon True, the callback will be called. Upon False it will not. Args is a list or tuple of arguments. Simplest example, assume ipr is the IPRoute() instance:: # create a simplest callback that will print messages def cb(msg): print(msg) # register callback for any message: ipr.register_callback(cb) More complex example, with filtering:: # Set object's attribute after the message key def cb(msg, obj): obj.some_attr = msg["some key"] # Register the callback only for the loopback device, index 1: ipr.register_callback(cb, lambda x: x.get('index', None) == 1, (self, )) Please note: you do **not** need to register the default 0 queue to invoke callbacks on broadcast messages. Callbacks are iterated **before** messages get enqueued. ''' if args is None: args = [] self.callbacks.append((predicate, callback, args)) def unregister_callback(self, callback): ''' Remove the first reference to the function from the callback register ''' cb = tuple(self.callbacks) for cr in cb: if cr[1] == callback: self.callbacks.pop(cb.index(cr)) return def register_policy(self, policy, msg_class=None): ''' Register netlink encoding/decoding policy. 
Can be specified in two ways: `nlsocket.register_policy(MSG_ID, msg_class)` to register one particular rule, or `nlsocket.register_policy({MSG_ID1: msg_class})` to register several rules at once. E.g.:: policy = {RTM_NEWLINK: ifinfmsg, RTM_DELLINK: ifinfmsg, RTM_NEWADDR: ifaddrmsg, RTM_DELADDR: ifaddrmsg} nlsocket.register_policy(policy) One can call `register_policy()` as many times, as one want to -- it will just extend the current policy scheme, not replace it. ''' if isinstance(policy, int) and msg_class is not None: policy = {policy: msg_class} if not isinstance(policy, dict): raise TypeError('wrong policy type') for key in policy: self.marshal.msg_map[key] = policy[key] return self.marshal.msg_map def unregister_policy(self, policy): ''' Unregister policy. Policy can be: - int -- then it will just remove one policy - list or tuple of ints -- remove all given - dict -- remove policies by keys from dict In the last case the routine will ignore dict values, it is implemented so just to make it compatible with `get_policy_map()` return value. ''' if isinstance(policy, int): policy = [policy] elif isinstance(policy, dict): policy = list(policy) if not isinstance(policy, (tuple, list, set)): raise TypeError('wrong policy type') for key in policy: del self.marshal.msg_map[key] return self.marshal.msg_map def get_policy_map(self, policy=None): ''' Return policy for a given message type or for all message types. Policy parameter can be either int, or a list of ints. Always return dictionary. 
''' if policy is None: return self.marshal.msg_map if isinstance(policy, int): policy = [policy] if not isinstance(policy, (list, tuple, set)): raise TypeError('wrong policy type') ret = {} for key in policy: ret[key] = self.marshal.msg_map[key] return ret def _peek_bufsize(self, socket_descriptor): data = bytearray() try: bufsize, _ = socket_descriptor.recvfrom_into( data, 0, MSG_DONTWAIT | MSG_PEEK | MSG_TRUNC ) except BlockingIOError: self._use_peek = False bufsize = socket_descriptor.getsockopt(SOL_SOCKET, SO_RCVBUF) // 2 return bufsize def sendto(self, *argv, **kwarg): return self._sendto(*argv, **kwarg) def recv(self, bufsize, flags=0): if self.input_from_buffer_queue: data_in = self.buffer_queue.get() if isinstance(data_in, Exception): raise data_in return data_in return self._sock.recv( self._peek_bufsize(self._sock) if self._use_peek else bufsize, flags, ) def recv_into(self, data, *argv, **kwarg): if self.input_from_buffer_queue: data_in = self.buffer_queue.get() if isinstance(data, Exception): raise data_in data[:] = data_in return len(data_in) return self._sock.recv_into(data, *argv, **kwarg) def buffer_thread_routine(self): poll = select.poll() poll.register(self._sock, select.POLLIN | select.POLLPRI) poll.register(self._ctrl_read, select.POLLIN | select.POLLPRI) sockfd = self._sock.fileno() while True: events = poll.poll() for fd, event in events: if fd == sockfd: try: data = bytearray(64000) self._sock.recv_into(data, 64000) self.buffer_queue.put_nowait(data) except Exception as e: self.buffer_queue.put(e) return else: return def compile(self): return CompileContext(self) def _send_batch(self, msgs, addr=(0, 0)): with self.backlog_lock: for msg in msgs: self.backlog[msg['header']['sequence_number']] = [] # We have locked the message locks in the caller already. 
data = bytearray() for msg in msgs: if not isinstance(msg, nlmsg): msg_class = self.marshal.msg_map[msg['header']['type']] msg = msg_class(msg) msg.reset() msg.encode() data += msg.data if self.compiled is not None: return self.compiled.append(data) self._sock.sendto(data, addr) def sendto_gate(self, msg, addr): msg.reset() msg.encode() if self.compiled is not None: return self.compiled.append(msg.data) return self._sock.sendto(msg.data, addr) def nlm_request_batch(self, msgs, noraise=False): """ This function is for messages which are expected to have side effects. Do not blindly retry in case of errors as this might duplicate them. """ expected_responses = [] acquired = 0 seqs = self.addr_pool.alloc_multi(len(msgs)) try: for seq in seqs: self.lock[seq].acquire() acquired += 1 for seq, msg in zip(seqs, msgs): msg['header']['sequence_number'] = seq if 'pid' not in msg['header']: msg['header']['pid'] = self.epid or os.getpid() if (msg['header']['flags'] & NLM_F_ACK) or ( msg['header']['flags'] & NLM_F_DUMP ): expected_responses.append(seq) self._send_batch(msgs) if self.compiled is not None: for data in self.compiled: yield data else: for seq in expected_responses: for msg in self.get(msg_seq=seq, noraise=noraise): if msg['header']['flags'] & NLM_F_DUMP_INTR: # Leave error handling to the caller raise NetlinkDumpInterrupted() yield msg finally: # Release locks in reverse order. for seq in seqs[acquired - 1 :: -1]: self.lock[seq].release() with self.backlog_lock: for seq in seqs: # Clear the backlog. We may have raised an error # causing the backlog to not be consumed entirely. 
if seq in self.backlog: del self.backlog[seq] self.addr_pool.free(seq, ban=0xFF) def nlm_request( self, msg, msg_type, msg_flags=NLM_F_REQUEST | NLM_F_DUMP, terminate=None, callback=None, parser=None, ): msg_seq = self.addr_pool.alloc() defer = None if callable(parser): self.marshal.seq_map[msg_seq] = parser with self.lock[msg_seq]: retry_count = 0 try: while True: try: self.put(msg, msg_type, msg_flags, msg_seq=msg_seq) if self.compiled is not None: for data in self.compiled: yield data else: for msg in self.get( msg_seq=msg_seq, terminate=terminate, callback=callback, ): # analyze the response for effects to be # deferred if ( defer is None and msg['header']['flags'] & NLM_F_DUMP_INTR ): defer = NetlinkDumpInterrupted() yield msg break except NetlinkError as e: if e.code != errno.EBUSY: raise if retry_count >= 30: raise log.warning('Error 16, retry {}.'.format(retry_count)) time.sleep(0.3) retry_count += 1 continue except Exception: raise finally: # Ban this msg_seq for 0xff rounds # # It's a long story. Modern kernels for RTM_SET.* # operations always return NLMSG_ERROR(0) == success, # even not setting NLM_F_MULTI flag on other response # messages and thus w/o any NLMSG_DONE. So, how to detect # the response end? One can not rely on NLMSG_ERROR on # old kernels, but we have to support them too. Ty, we # just ban msg_seq for several rounds, and NLMSG_ERROR, # being received, will become orphaned and just dropped. # # Hack, but true. 
                self.addr_pool.free(msg_seq, ban=0xFF)
                if msg_seq in self.marshal.seq_map:
                    self.marshal.seq_map.pop(msg_seq)
        if defer is not None:
            raise defer


class BatchAddrPool:
    # No-op address pool: batch mode uses sequence number 0 for all
    # messages.

    def alloc(self, *argv, **kwarg):
        return 0

    def free(self, *argv, **kwarg):
        pass


class BatchBacklogQueue(list):
    # Black-hole queue: accepts and drops everything.

    def append(self, *argv, **kwarg):
        pass

    def pop(self, *argv, **kwarg):
        pass


class BatchBacklog(dict):
    # Black-hole backlog for batch mode.

    def __getitem__(self, key):
        return BatchBacklogQueue()

    def __setitem__(self, key, value):
        pass

    def __delitem__(self, key):
        pass


class BatchSocket(NetlinkSocketBase):
    # Socket that only encodes requests into a byte buffer
    # (self.batch) instead of sending them to the kernel; used to
    # compile message batches.

    def post_init(self):
        self.backlog = BatchBacklog()
        self.addr_pool = BatchAddrPool()
        self._sock = None
        self.reset()

    def reset(self):
        self.batch = bytearray()

    def nlm_request(
        self,
        msg,
        msg_type,
        msg_flags=NLM_F_REQUEST | NLM_F_DUMP,
        terminate=None,
        callback=None,
    ):
        msg_seq = self.addr_pool.alloc()
        msg_pid = self.epid or os.getpid()
        msg['header']['type'] = msg_type
        msg['header']['flags'] = msg_flags
        msg['header']['sequence_number'] = msg_seq
        msg['header']['pid'] = msg_pid
        # append the encoded message to the batch buffer
        msg.data = self.batch
        msg.offset = len(self.batch)
        msg.encode()
        return []

    def get(self, *argv, **kwarg):
        pass


class NetlinkSocket(NetlinkSocketBase):
    def post_init(self):
        # recreate the underlying socket
        with self.sys_lock:
            if self._sock is not None:
                self._sock.close()
            self._sock = config.SocketBase(
                AF_NETLINK, SOCK_DGRAM, self.family, self._fileno
            )
            self.setsockopt(SOL_SOCKET, SO_SNDBUF, self._sndbuf)
            self.setsockopt(SOL_SOCKET, SO_RCVBUF, self._rcvbuf)
            if self.ext_ack:
                self.setsockopt(SOL_NETLINK, NETLINK_EXT_ACK, 1)
            if self.all_ns:
                self.setsockopt(SOL_NETLINK, NETLINK_LISTEN_ALL_NSID, 1)
            if self.strict_check:
                self.setsockopt(SOL_NETLINK, NETLINK_GET_STRICT_CHK, 1)

    def __getattr__(self, attr):
        # delegate a fixed set of socket methods to the underlying
        # socket object; '_sendto'/'_recv'/'_recv_into' map onto the
        # raw 'sendto'/'recv'/'recv_into'
        if attr in (
            'getsockname',
            'getsockopt',
            'makefile',
            'setsockopt',
            'setblocking',
            'settimeout',
            'gettimeout',
            'shutdown',
            'recvfrom',
            'recvfrom_into',
            'fileno',
        ):
            return getattr(self._sock, attr)
        elif attr in ('_sendto', '_recv', '_recv_into'):
            return getattr(self._sock, attr.lstrip("_"))
        raise AttributeError(attr)

    def bind(self, groups=0, pid=None, **kwarg):
        '''
        Bind the socket to given multicast groups, using
        given pid.

        - If pid is None, use automatic port allocation
        - If pid == 0, use process' pid
        - If pid == <int>, use the value instead of pid
        '''
        if pid is not None:
            self.port = 0
            self.fixed = True
            self.pid = pid or os.getpid()

        if 'async' in kwarg:
            # FIXME
            # raise deprecation error after 0.5.3
            #
            log.warning(
                'use "async_cache" instead of "async", '
                '"async" is a keyword from Python 3.7'
            )
        async_cache = kwarg.get('async_cache') or kwarg.get('async')

        self.groups = groups
        # if we have pre-defined port, use it strictly
        if self.fixed:
            self.epid = self.pid + (self.port << 22)
            self._sock.bind((self.epid, self.groups))
        else:
            # probe ports until bind() succeeds
            for port in range(1024):
                try:
                    self.port = port
                    self.epid = self.pid + (self.port << 22)
                    self._sock.bind((self.epid, self.groups))
                    break
                except Exception:
                    # create a new underlying socket -- on kernel 4
                    # one failed bind() makes the socket useless
                    self.post_init()
            else:
                raise KeyError('no free address available')
        # all is OK till now, so start async recv, if we need
        if async_cache:
            self.buffer_thread = threading.Thread(
                name="Netlink async cache", target=self.buffer_thread_routine
            )
            self.input_from_buffer_queue = True
            self.buffer_thread.daemon = True
            self.buffer_thread.start()

    def add_membership(self, group):
        self.setsockopt(SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, group)

    def drop_membership(self, group):
        self.setsockopt(SOL_NETLINK, NETLINK_DROP_MEMBERSHIP, group)

    def close(self, code=errno.ECONNRESET):
        '''
        Correctly close the socket and free all resources.
'''
        with self.sys_lock:
            if self.closed:
                return
            self.closed = True

            if self.buffer_thread:
                # wake the async cache thread via the control pipe so
                # it exits, then wait for it
                os.write(self._ctrl_write, b'exit')
                self.buffer_thread.join()
            # base class close(): pushes the error code into the
            # buffer queue and closes the control pipe
            super(NetlinkSocket, self).close(code=code)

            # Common shutdown procedure
            self._sock.close()


class ChaoticNetlinkSocket(NetlinkSocket):
    # Test helper: randomly fails get() calls to exercise error paths.
    success_rate = 1

    def __init__(self, *argv, **kwarg):
        # probability that a get() call is allowed to proceed
        self.success_rate = kwarg.pop('success_rate', 0.7)
        super(ChaoticNetlinkSocket, self).__init__(*argv, **kwarg)

    def get(self, *argv, **kwarg):
        if random.random() > self.success_rate:
            raise ChaoticException()
        return super(ChaoticNetlinkSocket, self).get(*argv, **kwarg)