diff --git a/mreg/mqsender.py b/mreg/mqsender.py index 399cacbb..aa46e7c2 100644 --- a/mreg/mqsender.py +++ b/mreg/mqsender.py @@ -1,72 +1,89 @@ -from django.conf import settings - -from datetime import datetime, timezone; import json -import os -import pika -import random import ssl -import string import time +from datetime import datetime, timezone + +import pika +from django.conf import settings +from pika.exceptions import (AMQPConnectionError, ConnectionClosedByBroker, + StreamLostError) + + +class MQSender: -from pika.exceptions import (ConnectionClosedByBroker, StreamLostError, AMQPConnectionError) + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(MQSender, cls).__new__(cls) + return cls.instance + def __init__(self): + self.mq_channel = None + self.mq_id : int = time.time_ns()//1_000_000 -class MQSender(object): + # Check the configuration + config = getattr(settings, 'MQ_CONFIG', None) - def __new__(cls): - if not hasattr(cls, 'instance'): - cls.instance = super(MQSender, cls).__new__(cls) - return cls.instance + # We accept an empty configuration, this disables the MQ + if config is None: + return + + # However, if we have a configuration, it must be a dictionary, + # and it must contain the keys host, username, password, and exchange + if not isinstance(config, dict): + raise ValueError('MQ_CONFIG must be a dictionary') + + for key in ['host', 'username', 'password', 'exchange']: + if key not in config: + raise ValueError(f'MQ_CONFIG must contain the key {key}') - def __init__(self): - self.mq_channel = None - self.mq_id : int = time.time_ns()//1_000_000 - def send_event(self, obj, routing_key): - config = getattr(settings, 'MQ_CONFIG', None) - if config is None: - return + def send_event(self, obj, routing_key): + config = getattr(settings, 'MQ_CONFIG', None) + if config is None: + return - # Add an id property to the event - obj['id'] = self.mq_id - self.mq_id += 1 + # Add an id property to the event + obj['id'] = self.mq_id + self.mq_id += 1 - # Add a timestamp to the event - local_time = datetime.now(timezone.utc).astimezone() - obj['timestamp'] = local_time.isoformat() + # Add a timestamp to the event + local_time = datetime.now(timezone.utc).astimezone() + obj['timestamp'] = local_time.isoformat() - for retry in range(10): - if self.mq_channel is None or self.mq_channel.connection.is_closed: - credentials = pika.credentials.PlainCredentials( - username=config['username'], - password=config['password'], - ) - ssl_options = None - if config.get('ssl',False): - ssl_context = ssl.create_default_context() - ssl_options = pika.SSLOptions(ssl_context, config['host']) - connection_parameters = pika.ConnectionParameters( - host=config['host'], - credentials=credentials, - ssl_options = ssl_options, - virtual_host = config.get('virtual_host','/'), - ) - try: - connection = pika.BlockingConnection(connection_parameters) - self.mq_channel = connection.channel() - if config.get('declare',False): - self.mq_channel.exchange_declare(exchange=config['exchange'], exchange_type='topic') - except AMQPConnectionError: - continue + for retry in range(10): + if self.mq_channel is None or self.mq_channel.connection.is_closed: + credentials = pika.credentials.PlainCredentials( + username=config['username'], + password=config['password'], + ) + ssl_options = None + if config.get('ssl', False): + ssl_context = ssl.create_default_context() + ssl_options = pika.SSLOptions(ssl_context, config['host']) + connection_parameters = pika.ConnectionParameters( + host=config['host'], + credentials=credentials, + ssl_options = ssl_options, + virtual_host = config.get('virtual_host','/'), + ) + try: + connection = pika.BlockingConnection(connection_parameters) + self.mq_channel = connection.channel() + if config.get('declare',False): + self.mq_channel.exchange_declare( + exchange=config['exchange'], + exchange_type='topic' + ) + except AMQPConnectionError: + continue - try: - self.mq_channel.basic_publish( - exchange=config['exchange'], - routing_key=routing_key, - body=json.dumps(obj), - properties=pika.BasicProperties(content_type="application/json"), - ) - break - except (ConnectionClosedByBroker, StreamLostError): - self.mq_channel = None + try: + self.mq_channel.basic_publish( + exchange=config['exchange'], + routing_key=routing_key, + body=json.dumps(obj), + properties=pika.BasicProperties(content_type="application/json"), + ) + break + except (ConnectionClosedByBroker, StreamLostError): + self.mq_channel = None diff --git a/mreg/tests.py b/mreg/tests.py index 2d046c5c..96bec99d 100644 --- a/mreg/tests.py +++ b/mreg/tests.py @@ -1,18 +1,27 @@ +import ipaddress +import json +import signal +import ssl from datetime import timedelta +from itertools import combinations +from unittest import mock from django.core.exceptions import ValidationError -from django.test import TestCase +from django.test import TestCase, override_settings from django.utils import timezone - +from pika import SSLOptions +from pika.exceptions import (AMQPConnectionError, ConnectionClosedByBroker, + StreamLostError) from rest_framework.exceptions import PermissionDenied -from .models import (Cname, ForwardZone, Host, HostGroup, Ipaddress, - Loc, NameServer, Naptr, +from mreg.mqsender import MQSender + +from .models import (MAX_UNUSED_LIST, Cname, ForwardZone, Host, HostGroup, + Ipaddress, Loc, NameServer, Naptr, NetGroupRegexPermission, Network, NetworkExcludedRange, - PtrOverride, ReverseZone, Srv, Sshfp, Txt, - MAX_UNUSED_LIST) + PtrOverride, ReverseZone, Srv, Sshfp, Txt) +from .mqsender import MQSender -import signal, ipaddress def clean_and_save(entity): entity.full_clean() @@ -1139,3 +1148,178 @@ def test_model_clean_permissions(self): self.assertEqual(NetGroupRegexPermission.objects.first(), v6perm) self.network_v6.delete() self.assertEqual(NetGroupRegexPermission.objects.count(), 0) + + +class MQSenderTest(TestCase): + """Test the MQSender class.""" + + def setUp(self): + """Set up the MQSender tests.""" + self.override_settings = override_settings(MQ_CONFIG={ + 'username': 'test', + 'password': 'test', + 'host': 'localhost', + 'ssl': False, + 'virtual_host': '/', + 'exchange': 'test_exchange' + }) + self.override_settings.enable() + self.mock_blocking_connection = mock.patch('pika.BlockingConnection') + + def tearDown(self): + """Tear down the MQSender tests.""" + self.override_settings.disable() + + def create_mq_sender_and_send_event(self, connection_side_effect=None): + """Create a MQSender and send an event, set side effects if requested.""" + with mock.patch('pika.BlockingConnection') as mock_blocking_connection: + if connection_side_effect is not None: + mock_blocking_connection.side_effect = connection_side_effect + mq_sender = MQSender() + mq_sender.send_event({'test': 'test'}, 'test_route') + return mq_sender + + def assert_event_not_published(self, mq_sender): + """Validate the channel is gone or the event is not published.""" + if mq_sender.mq_channel is not None: + assert not mq_sender.mq_channel.basic_publish.called + + def assert_event_published_correctly(self, mq_sender): + """Validate the event published has correct data.""" + assert mq_sender.mq_channel.basic_publish.called + _, kwargs = mq_sender.mq_channel.basic_publish.call_args + body = kwargs['body'] + body_dict = json.loads(body) + assert 'test' in body_dict + assert 'id' in body_dict + assert 'timestamp' in body_dict + assert body_dict['test'] == 'test' + + def test_send_event_default(self): + """Test that the event is published correctly.""" + mq_sender = self.create_mq_sender_and_send_event() + self.assert_event_published_correctly(mq_sender) + + def test_send_event_exception(self): + """Test that the event is published if the connection fails but retries work.""" + connection_side_effects = [AMQPConnectionError] * 3 + [mock.MagicMock()] + mq_sender = self.create_mq_sender_and_send_event(connection_side_effects) + self.assert_event_published_correctly(mq_sender) + + def test_send_event_failure_exception_ten_times(self): + """Test that the event is not published if the connection fails more than 10 times.""" + connection_side_effects = [AMQPConnectionError] * 11 + mq_sender = self.create_mq_sender_and_send_event(connection_side_effects) + self.assert_event_not_published(mq_sender) + + def test_singleton(self): + """Test that the MQSender class is a proper singleton.""" + instance1 = MQSender() + instance2 = MQSender() + self.assertEqual(instance1, instance2) + + def test_initialization(self): + """Test that the MQSender class is properly initialized.""" + mq_sender = MQSender() + self.assertIsNone(mq_sender.mq_channel) + self.assertIsInstance(mq_sender.mq_id, int) + + @mock.patch('pika.BlockingConnection') + def test_send_event_config(self, mock_blocking_connection): + """Test that we handle different configuration options.""" + keys = ['host', 'username', 'password', 'exchange'] + full_config = { + 'host': 'localhost', + 'username': 'test', + 'password': 'test', + 'exchange': 'test_exchange' + } + + # Test with MQ_CONFIG=None + with self.settings(MQ_CONFIG=None): + mq_sender = MQSender() + mq_sender.send_event({'test': 'test'}, 'test') + mock_blocking_connection.basic_publish.assert_not_called() + + # Test with MQ_CONFIG not a dictionary + with self.settings(MQ_CONFIG=[]): + with self.assertRaises(ValueError): + mq_sender = MQSender() + + # Test all combinations of missing keys + for r in range(1, len(keys) + 1): + for combination in combinations(keys, r): + test_config = full_config.copy() + for key in combination: + del test_config[key] + with self.settings(MQ_CONFIG=test_config): + with self.assertRaises(ValueError): + mq_sender = MQSender() + + @mock.patch('ssl.SSLContext.load_cert_chain') + @mock.patch('pika.BlockingConnection') + def test_send_event_with_ssl_and_declare( + self, + mock_blocking_connection, + mock_load_cert_chain + ): + """Test that SSL and declare options are handled correctly.""" + mock_connection = mock.MagicMock() + mock_channel = mock.MagicMock() + mock_connection.channel.return_value = mock_channel + mock_blocking_connection.return_value = mock_connection + + ssl_context = ssl.create_default_context() + ssl_options = SSLOptions(ssl_context, 'localhost') + + with override_settings(MQ_CONFIG={ + 'username': 'test', + 'password': 'test', + 'host': 'localhost', + 'ssl': ssl_options, + 'declare': True, + 'virtual_host': '/', + 'exchange': 'test_exchange' + }): + mq_sender = MQSender() + mq_sender.send_event({'test': 'test'}, 'test_route') + + # Assert that exchange_declare was called if 'declare' option is set to True + mock_channel.exchange_declare.assert_called_once_with( + exchange='test_exchange', + exchange_type='topic' + ) + + def _run_basic_publish_test(self, side_effects, expect_published): + with mock.patch('pika.BlockingConnection') as mock_blocking_connection: + mock_connection = mock.MagicMock() + mock_channel = mock.MagicMock() + mock_channel.basic_publish.side_effect = side_effects + mock_connection.channel.return_value = mock_channel + mock_blocking_connection.return_value = mock_connection + + mq_sender = MQSender() + mq_sender.send_event({'test': 'test'}, 'test_route') + + if expect_published: + self.assert_event_published_correctly(mq_sender) + else: + self.assertIsNone(mq_sender.mq_channel) + + def test_send_event_basic_publish_raises_exceptions(self): + """Test that exceptions in basic_publish are handled correctly. + + We specifically handle ConnectionClosedByBroker and StreamLostError.""" + + # Test when exceptions are raised and no event is published + self._run_basic_publish_test([ + ConnectionClosedByBroker(reply_code=123, reply_text='test'), + StreamLostError] * 6, + expect_published=False) + + # Test when exceptions are raised, but eventually an event is published + self._run_basic_publish_test([ + ConnectionClosedByBroker(reply_code=123, reply_text='test'), + StreamLostError, + mock.MagicMock()], + expect_published=True) \ No newline at end of file