Created May 9, 2023 14:43
-
-
Save pc-m/fbe3a7f857b8a468280cbc49e0d42cce to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/wazo_websocketd/bus.py b/wazo_websocketd/bus.py | |
| index a755642..9cf6375 100644 | |
| --- a/wazo_websocketd/bus.py | |
| +++ b/wazo_websocketd/bus.py | |
| @@ -74,7 +74,7 @@ class _UserHelper: | |
| class _BusConnection: | |
| - _id_counter = Value('i', 1) # process-safe shared counter | |
| + _id_counter = Value('i', 1) | |
| def __init__(self, url: str, *, loop: asyncio.AbstractEventLoop = None): | |
| self._id: int = self._get_unique_id() | |
| @@ -169,10 +169,8 @@ class _BusConnection: | |
| f'[connection {self._id}] failed to create a new channel' | |
| ) | |
| - def spawn_consumer( | |
| - self, exchange: str, token: str, *, prefetch: int = None | |
| - ) -> BusConsumer: | |
| - consumer = BusConsumer(self, exchange, token, prefetch=prefetch) | |
| + def spawn_consumer(self, config: dict, token: str) -> BusConsumer: | |
| + consumer = BusConsumer(self, config, token) | |
| self._consumers.append(consumer) | |
| return consumer | |
| @@ -227,24 +225,16 @@ class _BusConnectionPool: | |
| class BusConsumer: | |
| - DEFAULT_PREFETCH_COUNT: int = 200 | |
| - | |
| - def __init__( | |
| - self, | |
| - connection: _BusConnection, | |
| - exchange: str, | |
| - token: str, | |
| - *, | |
| - prefetch: int = None, | |
| - ): | |
| + def __init__(self, connection: _BusConnection, config: dict, token: str): | |
| self.set_token(token) | |
| self._amqp_queue: str | None = None | |
| self._bound_exchange: str | None = None | |
| self._channel: Channel = None | |
| self._connection: _BusConnection = connection | |
| self._consumer_tag: str | None = None | |
| - self._exchange_name: str = exchange | |
| - self._prefetch: int = prefetch or self.DEFAULT_PREFETCH_COUNT | |
| + self._exchange_name: str = config['bus']['exchange_name'] | |
| + self._prefetch: int = config['bus']['consumer_prefetch'] | |
| + self._origin_uuid: str = config['uuid'] | |
| self._queue = asyncio.Queue() | |
| async def __aenter__(self): | |
| @@ -257,7 +247,7 @@ class BusConsumer: | |
| def __aiter__(self): | |
| return self | |
| - async def __anext__(self) -> 'BusMessage': | |
| + async def __anext__(self) -> BusMessage: | |
| payload = await self._queue.get() | |
| if isinstance(payload, Exception): | |
| raise payload | |
| @@ -280,7 +270,10 @@ class BusConsumer: | |
| ) | |
| await channel.exchange_bind( | |
| - tenant_exchange, exchange, '', arguments={'tenant_uuid': tenant_uuid} | |
| + tenant_exchange, | |
| + exchange, | |
| + '', | |
| + arguments={'origin_uuid': self._origin_uuid, 'tenant_uuid': tenant_uuid}, | |
| ) | |
| return tenant_exchange | |
| @@ -330,6 +323,21 @@ class BusConsumer: | |
| return BusMessage(event_name, headers, acl, message, decoded) | |
| + def _generate_bindings(self, event_name: str) -> list[dict]: | |
| + binding = {} | |
| + if event_name != '*': | |
| + binding['name'] = event_name | |
| + | |
| + if self._user.is_admin(): | |
| + binding['origin_uuid'] = self._origin_uuid | |
| + return [binding] | |
| + | |
| + # note: users don't need origin_uuid because the tenant exchange takes care of it | |
| + return [ | |
| + binding | {f'user_uuid:{self._user.uuid}': True}, | |
| + binding | {'user_uuid:*': True}, | |
| + ] | |
| + | |
| def _has_access(self, acl: str) -> bool: | |
| return self._access.matches_required_access(acl) | |
| @@ -380,13 +388,7 @@ class BusConsumer: | |
| self._connection.remove_consumer(self) | |
| async def bind(self, event_name: str) -> None: | |
| - bindings = [{}] | |
| - if not self._user.is_admin(): | |
| - bindings = [{f'user_uuid:{uuid}': True} for uuid in (self._user.uuid, '*')] | |
| - | |
| - for binding in bindings: | |
| - if event_name != '*': | |
| - binding['name'] = event_name | |
| + for binding in self._generate_bindings(event_name): | |
| await self._channel.queue_bind( | |
| self._amqp_queue, self._bound_exchange, '', arguments=binding | |
| ) | |
| @@ -395,13 +397,7 @@ class BusConsumer: | |
| self._queue.put_nowait(BusConnectionLostError()) | |
| async def unbind(self, event_name: str) -> None: | |
| - bindings = [{}] | |
| - if not self._user.is_admin(): | |
| - bindings = [{f'user_uuid:{uuid}': True} for uuid in (self._user.uuid, '*')] | |
| - | |
| - for binding in bindings or [{}]: | |
| - if event_name != '*': | |
| - binding['name'] = event_name | |
| + for binding in self._generate_bindings(event_name): | |
| await self._channel.queue_unbind( | |
| self._amqp_queue, self._bound_exchange, '', arguments=binding | |
| ) | |
| @@ -446,10 +442,7 @@ class BusService: | |
| async def create_consumer(self, token: str) -> BusConsumer: | |
| connection = self._connection_pool.get_connection() | |
| - exchange = self._config['bus']['exchange_name'] | |
| - prefetch = self._config['bus']['consumer_prefetch'] | |
| - | |
| - return connection.spawn_consumer(exchange, token, prefetch=prefetch) | |
| + return connection.spawn_consumer(self._config, token) | |
| async def initialize_exchanges(self): | |
| async def create_exchange(config: dict, channel: Channel): | |
| diff --git a/wazo_websocketd/controller.py b/wazo_websocketd/controller.py | |
| index 0c3602d..1cc1ddc 100644 | |
| --- a/wazo_websocketd/controller.py | |
| +++ b/wazo_websocketd/controller.py | |
| @@ -15,13 +15,13 @@ logger = logging.getLogger(__name__) | |
| class Controller: | |
| - def __init__(self, config): | |
| + def __init__(self, config: dict): | |
| self._config = config | |
| async def _initialize(self, tombstone: Future): | |
| async with BusService(self._config) as service: | |
| - futs = {service.initialize_exchanges(), tombstone} | |
| - await asyncio.wait(futs, return_when=FIRST_COMPLETED) | |
| + results = {service.initialize_exchanges(), tombstone} | |
| + await asyncio.wait(results, return_when=FIRST_COMPLETED) | |
| async def _run(self): | |
| tombstone = asyncio.Future() | |
| @@ -40,7 +40,7 @@ class Controller: | |
| ) | |
| async with ProcessPool(self._config): | |
| - await tombstone # wait for SIGTERM or SIGINT` | |
| + await tombstone # wait for SIGTERM or SIGINT | |
| logger.info('wazo-websocketd stopped') | |
| diff --git a/wazo_websocketd/main.py b/wazo_websocketd/main.py | |
| index 7739de6..40f23ad 100644 | |
| --- a/wazo_websocketd/main.py | |
| +++ b/wazo_websocketd/main.py | |
| @@ -6,12 +6,16 @@ import logging | |
| import uvloop | |
| from xivo import xivo_logging | |
| +from xivo.config_helper import set_xivo_uuid | |
| from xivo.user_rights import change_user | |
| from wazo_websocketd.config import load_config | |
| from wazo_websocketd.controller import Controller | |
| +logger = logging.getLogger(__name__) | |
| + | |
| + | |
| def main(): | |
| asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) | |
| config = load_config() | |
| @@ -21,6 +25,7 @@ def main(): | |
| ) | |
| xivo_logging.silence_loggers(['urllib3'], logging.WARNING) | |
| xivo_logging.silence_loggers(['aioamqp'], logging.WARNING) | |
| + set_xivo_uuid(config, logger) | |
| if config['user']: | |
| change_user(config['user']) | |
| diff --git a/wazo_websocketd/process.py b/wazo_websocketd/process.py | |
| index fcb4b65..d0935c5 100644 | |
| --- a/wazo_websocketd/process.py | |
| +++ b/wazo_websocketd/process.py | |
| @@ -8,9 +8,10 @@ import websockets | |
| from multiprocessing import get_context | |
| from multiprocessing.sharedctypes import Synchronized | |
| -from os import getpid, sched_getaffinity | |
| +from os import getpid, sched_getaffinity, chdir | |
| from setproctitle import setproctitle | |
| from signal import SIGINT, SIGTERM | |
| +from tempfile import TemporaryDirectory | |
| from websockets.server import Serve | |
| from xivo.xivo_logging import silence_loggers, setup_logging | |
| @@ -73,8 +74,10 @@ class ProcessPool: | |
| ) | |
| self._workers = workers | |
| self._config = config | |
| + self._dir = TemporaryDirectory(prefix="wazo-websocketd-") | |
| context = get_context('spawn') | |
| + chdir(self._dir.name) | |
| self._pool = context.Pool( | |
| workers, self._init_worker, (config, MasterTenantProxy.proxy) | |
| ) | |
| @@ -88,6 +91,7 @@ class ProcessPool: | |
| async def __aexit__(self, *args): | |
| self._pool.close() | |
| self._pool.join() | |
| + self._dir.cleanup() | |
| @staticmethod | |
| def _init_worker(config: dict, master_tenant_proxy: Synchronized): | |
| @@ -99,7 +103,7 @@ class ProcessPool: | |
| setup_logging( | |
| config['log_file'], debug=config['debug'], log_level=config['log_level'] | |
| ) | |
| - silence_loggers(['aioamqp', 'urllib3'], logging.WARNING) | |
| + silence_loggers(['aioamqp', 'urllib3', 'stevedore.extension'], logging.WARNING) | |
| @staticmethod | |
| def _run(config: dict): |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.