Module server.servercontext
Manages a group of connections that use the same protocol over the same port.
Classes
class ServerContext (name: str, connection_factory: Callable[[], LobbyConnection], services: Iterable[Service], protocol_class: type[Protocol] = server.protocol.qdatastream.QDataStreamProtocol)
-
Base class for managing connections and holding state about them.
Expand source code
@with_logger
class ServerContext:
    """
    Base class for managing connections and holding state about them.

    A context owns the server object created by `listen` and keeps a mapping
    from every live connection to the protocol instance wrapping its stream.
    """

    def __init__(
        self,
        name: str,
        connection_factory: Callable[[], LobbyConnection],
        services: Iterable[Service],
        protocol_class: type[Protocol] = QDataStreamProtocol,
    ):
        """
        # Params
        - `name`: Label used as a prefix in every log line of this context.
        - `connection_factory`: Called with no arguments to build the
            connection object for each newly accepted client.
        - `services`: Objects whose `on_connection_lost` is called whenever a
            client goes away.
        - `protocol_class`: Called with `(reader, writer)` for each client and
            used to encode broadcast messages.
        """
        super().__init__()
        self.name = name
        self._server = None
        # Created lazily by `drain_connections`
        self._drain_event = None
        self._connection_factory = connection_factory
        self._services = services
        self.connections: dict[LobbyConnection, Protocol] = {}
        self.protocol_class = protocol_class

    def __repr__(self):
        return f"ServerContext({self.name})"

    async def listen(
        self,
        host: str,
        port: Optional[int],
        proxy: bool = False
    ):
        """
        Start a server on `host`:`port` and return it.

        When `proxy` is true the client callback is wrapped by a
        `ProxyProtocolReader` before being handed to `asyncio.start_server`.
        `SO_KEEPALIVE` is enabled on every listening socket.
        """
        self._logger.debug(
            "%s: listen(%r, %r, proxy=%r)",
            self.name,
            host,
            port,
            proxy
        )

        callback = self.client_connected_callback
        if proxy:
            pp_detect = ProxyProtocolDetect()
            pp_reader = ProxyProtocolReader(pp_detect)
            callback = pp_reader.get_callback(callback)

        self._server = await asyncio.start_server(
            callback,
            host=host,
            port=port,
            limit=LIMIT,
        )

        for sock in self.sockets:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # Separate local names so the `host`/`port` arguments are not
            # clobbered. The star target absorbs any extra fields in the
            # socket name (IPv6 addresses have four).
            bound_host, bound_port, *_ = sock.getsockname()
            self._logger.info(
                "%s: listening on %s:%s", self.name, bound_host, bound_port
            )

        return self._server

    @property
    def sockets(self):
        """The `sockets` attribute of the server created by `listen`."""
        return self._server.sockets

    async def shutdown(self, timeout: Optional[float] = 5):
        """
        Close the protocol of every tracked connection concurrently.

        Each protocol gets up to `timeout` seconds to close; one that times
        out is aborted and a warning is logged.
        """
        async def close_or_abort(conn, proto):
            try:
                await asyncio.wait_for(proto.close(), timeout)
            except asyncio.TimeoutError:
                proto.abort()
                self._logger.warning(
                    "%s: Protocol did not terminate cleanly for '%s'",
                    self.name,
                    conn.get_user_identifier()
                )

        self._logger.debug(
            "%s: Waiting up to %s for connections to close",
            self.name,
            humanize.naturaldelta(timeout)
        )
        # The list is built up front, so entries removed from
        # `self.connections` while we await do not disturb the iteration.
        for fut in asyncio.as_completed([
            close_or_abort(conn, proto)
            for conn, proto in self.connections.items()
        ]):
            await fut
        self._logger.debug("%s: All connections closed", self.name)

    async def stop(self):
        """Close the server, if one was started, and wait until it is closed."""
        self._logger.debug("%s: stop()", self.name)
        if self._server:
            self._server.close()
            await self._server.wait_closed()

    async def drain_connections(self):
        """
        Wait for all connections to terminate.
        """
        if not self.connections:
            return

        # An event left set by an earlier drain would make `wait()` return
        # immediately even though connections are live again, so replace it.
        if self._drain_event is None or self._drain_event.is_set():
            self._drain_event = asyncio.Event()

        await self._drain_event.wait()

    def write_broadcast(self, message, validate_fn=lambda _: True):
        """
        Encode `message` with `protocol_class.encode_message` and send the
        result through `write_broadcast_raw`.
        """
        self.write_broadcast_raw(
            self.protocol_class.encode_message(message),
            validate_fn
        )

    def write_broadcast_raw(self, data, validate_fn=lambda _: True):
        """
        Write `data` to every connected protocol whose connection is accepted
        by `validate_fn`.

        An exception raised for one connection is logged and does not stop
        the broadcast to the remaining ones.
        """
        for conn, proto in self.connections.items():
            try:
                if proto.is_connected() and validate_fn(conn):
                    proto.write_raw(data)
            except Exception:
                self._logger.exception(
                    "%s: Encountered error in broadcast: %s",
                    self.name,
                    conn
                )

    async def client_connected_callback(
        self,
        reader: StreamReader,
        writer: StreamWriter,
        proxy_info: Optional[SocketInfo] = None,
    ):
        """
        Entry point for a newly accepted stream.

        Determines the peer address (from `proxy_info` when given, otherwise
        from the writer), logs it, and hands over to
        `handle_client_connected`. A connection that has `proxy_info` but no
        peer name in it is closed without being handled.
        """
        # NOTE(review): `Address(*...)` assumes the peername tuple has exactly
        # the fields `Address` accepts. Confirm this holds for IPv6 peers,
        # whose peername tuples carry extra fields.
        if proxy_info:
            peername_writer = Address(*writer.get_extra_info("peername"))
            if not proxy_info.peername:
                # See security considerations:
                # https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
                self._logger.warning(
                    "%s: Client connected from %s:%s to a context in proxy "
                    "mode! The connection will be ignored, however this may "
                    "indicate a misconfiguration in your firewall.",
                    self.name,
                    peername_writer.host,
                    peername_writer.port
                )
                writer.close()
                return

            peername = Address(*proxy_info.peername)
            self._logger.info(
                "%s: Client connected from %s:%s via proxy %s:%s",
                self.name,
                peername.host,
                peername.port,
                peername_writer.host,
                peername_writer.port
            )
        else:
            peername = Address(*writer.get_extra_info("peername"))
            self._logger.info(
                "%s: Client connected from %s:%s",
                self.name,
                peername.host,
                peername.port
            )

        await self.handle_client_connected(reader, writer, peername)

    async def handle_client_connected(
        self,
        reader: StreamReader,
        writer: StreamWriter,
        peername: Address,
    ):
        """
        Serve one client: register it, feed every message read from the
        protocol to the connection, and clean up once the loop ends.

        Cleanup always runs. It unregisters the connection, aborts the
        protocol, notifies the services and the connection itself, and
        releases `drain_connections` waiters when no connections remain.
        """
        protocol = self.protocol_class(reader, writer)
        connection = self._connection_factory()
        self.connections[connection] = protocol
        # Tracks whether the gauge was incremented, so the decrement during
        # cleanup cannot run without a matching increment.
        counted = False

        try:
            await connection.on_connection_made(protocol, peername)
            metrics.user_connections.labels("None", "None").inc()
            counted = True
            while protocol.is_connected():
                message = await protocol.read_message()
                with metrics.connection_on_message_received.time():
                    await connection.on_message_received(message)
        except (
            ConnectionError,
            DisconnectedError,
            TimeoutError,
            asyncio.CancelledError,
        ):
            # Treated as an ordinary disconnect: nothing is logged here
            pass
        except UnicodeDecodeError as e:
            # Clamp the start of the context window: a negative index would
            # be interpreted relative to the end of the data.
            context_start = max(0, e.start - 20)
            self._logger.exception(
                "%s: Unicode error in protocol for '%s': %s '...%s...'",
                self.name,
                connection.get_user_identifier(),
                e,
                e.object[context_start:e.end + 20]
            )
        except Exception as e:
            self._logger.exception(
                "%s: Exception in protocol for '%s': %s",
                self.name,
                connection.get_user_identifier(),
                e
            )
        finally:
            del self.connections[connection]
            # Do not wait for buffers to empty here. This could stop the process
            # from exiting if the client isn't reading data.
            protocol.abort()
            for service in self._services:
                with self.suppress_and_log(service.on_connection_lost, Exception):
                    service.on_connection_lost(connection)

            with self.suppress_and_log(connection.on_connection_lost, Exception):
                await connection.on_connection_lost()

            self._logger.info(
                "%s: Client disconnected for '%s'",
                self.name,
                connection.get_user_identifier()
            )

            if (
                self._drain_event is not None
                and not self._drain_event.is_set()
                and not self.connections
            ):
                self._drain_event.set()

            if counted:
                metrics.user_connections.labels(
                    connection.user_agent,
                    connection.version
                ).dec()

    @contextmanager
    def suppress_and_log(self, func, *exceptions: type[BaseException]):
        """
        Context manager that swallows the given exception types and logs a
        warning, with traceback, naming `func` as the source.

        `func` is used only to build the description in the log message; it
        is not called here.
        """
        try:
            yield
        except exceptions:
            if hasattr(func, "__self__"):
                desc = f"{func.__self__.__class__.__name__}.{func.__name__}"
            else:
                desc = func.__name__
            self._logger.warning(
                "Unexpected exception in %s",
                desc,
                exc_info=True
            )
Instance variables
prop sockets
-
Expand source code
@property
def sockets(self):
    """Return the `sockets` attribute of the underlying server object."""
    server = self._server
    return server.sockets
Methods
async def client_connected_callback(self, reader: asyncio.streams.StreamReader, writer: asyncio.streams.StreamWriter, proxy_info: Optional[proxyprotocol.sock.SocketInfo] = None)
async def drain_connections(self)
-
Wait for all connections to terminate.
async def handle_client_connected(self, reader: asyncio.streams.StreamReader, writer: asyncio.streams.StreamWriter, peername: Address)
async def listen(self, host: str, port: Optional[int], proxy: bool = False)
async def shutdown(self, timeout: Optional[float] = 5)
async def stop(self)
def suppress_and_log(self, func, *exceptions: type[BaseException])
def write_broadcast(self, message, validate_fn=lambda _: True)
def write_broadcast_raw(self, data, validate_fn=lambda _: True)