Skip to content

Instantly share code, notes, and snippets.

@pc-m
Created May 9, 2023 14:43
Show Gist options
  • Select an option

  • Save pc-m/fbe3a7f857b8a468280cbc49e0d42cce to your computer and use it in GitHub Desktop.

Select an option

Save pc-m/fbe3a7f857b8a468280cbc49e0d42cce to your computer and use it in GitHub Desktop.
diff --git a/wazo_websocketd/bus.py b/wazo_websocketd/bus.py
index a755642..9cf6375 100644
--- a/wazo_websocketd/bus.py
+++ b/wazo_websocketd/bus.py
@@ -74,7 +74,7 @@ class _UserHelper:
class _BusConnection:
- _id_counter = Value('i', 1) # process-safe shared counter
+ _id_counter = Value('i', 1)
def __init__(self, url: str, *, loop: asyncio.AbstractEventLoop = None):
self._id: int = self._get_unique_id()
@@ -169,10 +169,8 @@ class _BusConnection:
f'[connection {self._id}] failed to create a new channel'
)
- def spawn_consumer(
- self, exchange: str, token: str, *, prefetch: int = None
- ) -> BusConsumer:
- consumer = BusConsumer(self, exchange, token, prefetch=prefetch)
+ def spawn_consumer(self, config: dict, token: str) -> BusConsumer:
+ consumer = BusConsumer(self, config, token)
self._consumers.append(consumer)
return consumer
@@ -227,24 +225,16 @@ class _BusConnectionPool:
class BusConsumer:
- DEFAULT_PREFETCH_COUNT: int = 200
-
- def __init__(
- self,
- connection: _BusConnection,
- exchange: str,
- token: str,
- *,
- prefetch: int = None,
- ):
+ def __init__(self, connection: _BusConnection, config: dict, token: str):
self.set_token(token)
self._amqp_queue: str | None = None
self._bound_exchange: str | None = None
self._channel: Channel = None
self._connection: _BusConnection = connection
self._consumer_tag: str | None = None
- self._exchange_name: str = exchange
- self._prefetch: int = prefetch or self.DEFAULT_PREFETCH_COUNT
+ self._exchange_name: str = config['bus']['exchange_name']
+ self._prefetch: int = config['bus']['consumer_prefetch']
+ self._origin_uuid: str = config['uuid']
self._queue = asyncio.Queue()
async def __aenter__(self):
@@ -257,7 +247,7 @@ class BusConsumer:
def __aiter__(self):
return self
- async def __anext__(self) -> 'BusMessage':
+ async def __anext__(self) -> BusMessage:
payload = await self._queue.get()
if isinstance(payload, Exception):
raise payload
@@ -280,7 +270,10 @@ class BusConsumer:
)
await channel.exchange_bind(
- tenant_exchange, exchange, '', arguments={'tenant_uuid': tenant_uuid}
+ tenant_exchange,
+ exchange,
+ '',
+ arguments={'origin_uuid': self._origin_uuid, 'tenant_uuid': tenant_uuid},
)
return tenant_exchange
@@ -330,6 +323,21 @@ class BusConsumer:
return BusMessage(event_name, headers, acl, message, decoded)
+ def _generate_bindings(self, event_name: str) -> list[dict]:
+ binding = {}
+ if event_name != '*':
+ binding['name'] = event_name
+
+ if self._user.is_admin():
+ binding['origin_uuid'] = self._origin_uuid
+ return [binding]
+
+ # note: non-admin bindings omit origin_uuid because the tenant exchange binding already filters on it
+ return [
+ binding | {f'user_uuid:{self._user.uuid}': True},
+ binding | {'user_uuid:*': True},
+ ]
+
def _has_access(self, acl: str) -> bool:
return self._access.matches_required_access(acl)
@@ -380,13 +388,7 @@ class BusConsumer:
self._connection.remove_consumer(self)
async def bind(self, event_name: str) -> None:
- bindings = [{}]
- if not self._user.is_admin():
- bindings = [{f'user_uuid:{uuid}': True} for uuid in (self._user.uuid, '*')]
-
- for binding in bindings:
- if event_name != '*':
- binding['name'] = event_name
+ for binding in self._generate_bindings(event_name):
await self._channel.queue_bind(
self._amqp_queue, self._bound_exchange, '', arguments=binding
)
@@ -395,13 +397,7 @@ class BusConsumer:
self._queue.put_nowait(BusConnectionLostError())
async def unbind(self, event_name: str) -> None:
- bindings = [{}]
- if not self._user.is_admin():
- bindings = [{f'user_uuid:{uuid}': True} for uuid in (self._user.uuid, '*')]
-
- for binding in bindings or [{}]:
- if event_name != '*':
- binding['name'] = event_name
+ for binding in self._generate_bindings(event_name):
await self._channel.queue_unbind(
self._amqp_queue, self._bound_exchange, '', arguments=binding
)
@@ -446,10 +442,7 @@ class BusService:
async def create_consumer(self, token: str) -> BusConsumer:
connection = self._connection_pool.get_connection()
- exchange = self._config['bus']['exchange_name']
- prefetch = self._config['bus']['consumer_prefetch']
-
- return connection.spawn_consumer(exchange, token, prefetch=prefetch)
+ return connection.spawn_consumer(self._config, token)
async def initialize_exchanges(self):
async def create_exchange(config: dict, channel: Channel):
diff --git a/wazo_websocketd/controller.py b/wazo_websocketd/controller.py
index 0c3602d..1cc1ddc 100644
--- a/wazo_websocketd/controller.py
+++ b/wazo_websocketd/controller.py
@@ -15,13 +15,13 @@ logger = logging.getLogger(__name__)
class Controller:
- def __init__(self, config):
+ def __init__(self, config: dict):
self._config = config
async def _initialize(self, tombstone: Future):
async with BusService(self._config) as service:
- futs = {service.initialize_exchanges(), tombstone}
- await asyncio.wait(futs, return_when=FIRST_COMPLETED)
+ results = {service.initialize_exchanges(), tombstone}
+ await asyncio.wait(results, return_when=FIRST_COMPLETED)
async def _run(self):
tombstone = asyncio.Future()
@@ -40,7 +40,7 @@ class Controller:
)
async with ProcessPool(self._config):
- await tombstone # wait for SIGTERM or SIGINT`
+ await tombstone # wait for SIGTERM or SIGINT
logger.info('wazo-websocketd stopped')
diff --git a/wazo_websocketd/main.py b/wazo_websocketd/main.py
index 7739de6..40f23ad 100644
--- a/wazo_websocketd/main.py
+++ b/wazo_websocketd/main.py
@@ -6,12 +6,16 @@ import logging
import uvloop
from xivo import xivo_logging
+from xivo.config_helper import set_xivo_uuid
from xivo.user_rights import change_user
from wazo_websocketd.config import load_config
from wazo_websocketd.controller import Controller
+logger = logging.getLogger(__name__)
+
+
def main():
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
config = load_config()
@@ -21,6 +25,7 @@ def main():
)
xivo_logging.silence_loggers(['urllib3'], logging.WARNING)
xivo_logging.silence_loggers(['aioamqp'], logging.WARNING)
+ set_xivo_uuid(config, logger)
if config['user']:
change_user(config['user'])
diff --git a/wazo_websocketd/process.py b/wazo_websocketd/process.py
index fcb4b65..d0935c5 100644
--- a/wazo_websocketd/process.py
+++ b/wazo_websocketd/process.py
@@ -8,9 +8,10 @@ import websockets
from multiprocessing import get_context
from multiprocessing.sharedctypes import Synchronized
-from os import getpid, sched_getaffinity
+from os import getpid, sched_getaffinity, chdir
from setproctitle import setproctitle
from signal import SIGINT, SIGTERM
+from tempfile import TemporaryDirectory
from websockets.server import Serve
from xivo.xivo_logging import silence_loggers, setup_logging
@@ -73,8 +74,10 @@ class ProcessPool:
)
self._workers = workers
self._config = config
+ self._dir = TemporaryDirectory(prefix="wazo-websocketd-")
context = get_context('spawn')
+ chdir(self._dir.name)
self._pool = context.Pool(
workers, self._init_worker, (config, MasterTenantProxy.proxy)
)
@@ -88,6 +91,7 @@ class ProcessPool:
async def __aexit__(self, *args):
self._pool.close()
self._pool.join()
+ self._dir.cleanup()
@staticmethod
def _init_worker(config: dict, master_tenant_proxy: Synchronized):
@@ -99,7 +103,7 @@ class ProcessPool:
setup_logging(
config['log_file'], debug=config['debug'], log_level=config['log_level']
)
- silence_loggers(['aioamqp', 'urllib3'], logging.WARNING)
+ silence_loggers(['aioamqp', 'urllib3', 'stevedore.extension'], logging.WARNING)
@staticmethod
def _run(config: dict):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment