Skip to content

Commit

Permalink
Test mqsender, validate input, style update.
Browse files Browse the repository at this point in the history
  - Complete test suite for MQSender.
  - Adds validation (and testing) of MQ_CONFIG keys.
  - Isorted imports in affected files due to lots of changes.
  - Tab to space conversion in mqsender.py.
  • Loading branch information
terjekv authored and oyvindhagberg committed Jun 7, 2023
1 parent 806be01 commit e8424ee
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 66 deletions.
135 changes: 76 additions & 59 deletions mreg/mqsender.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,89 @@
from django.conf import settings

from datetime import datetime, timezone;
import json
import os
import pika
import random
import ssl
import string
import time
from datetime import datetime, timezone

import pika
from django.conf import settings
from pika.exceptions import (AMQPConnectionError, ConnectionClosedByBroker,
StreamLostError)


class MQSender:

from pika.exceptions import (ConnectionClosedByBroker, StreamLostError, AMQPConnectionError)
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(MQSender, cls).__new__(cls)
return cls.instance

def __init__(self):
self.mq_channel = None
self.mq_id : int = time.time_ns()//1_000_000

class MQSender(object):
# Check the configuration
config = getattr(settings, 'MQ_CONFIG', None)

def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(MQSender, cls).__new__(cls)
return cls.instance
# We accept an empty configuration, this disables the MQ
if config is None:
return

# However, if we have a configuration, it must be a dictionary,
# and it must contain the keys host, username, password, and exchange
if not isinstance(config, dict):
raise ValueError('MQ_CONFIG must be a dictionary')

for key in ['host', 'username', 'password', 'exchange']:
if key not in config:
raise ValueError(f'MQ_CONFIG must contain the key {key}')

def __init__(self):
self.mq_channel = None
self.mq_id : int = time.time_ns()//1_000_000

def send_event(self, obj, routing_key):
config = getattr(settings, 'MQ_CONFIG', None)
if config is None:
return
def send_event(self, obj, routing_key):
config = getattr(settings, 'MQ_CONFIG', None)
if config is None:
return

# Add an id property to the event
obj['id'] = self.mq_id
self.mq_id += 1
# Add an id property to the event
obj['id'] = self.mq_id
self.mq_id += 1

# Add a timestamp to the event
local_time = datetime.now(timezone.utc).astimezone()
obj['timestamp'] = local_time.isoformat()
# Add a timestamp to the event
local_time = datetime.now(timezone.utc).astimezone()
obj['timestamp'] = local_time.isoformat()

for retry in range(10):
if self.mq_channel is None or self.mq_channel.connection.is_closed:
credentials = pika.credentials.PlainCredentials(
username=config['username'],
password=config['password'],
)
ssl_options = None
if config.get('ssl',False):
ssl_context = ssl.create_default_context()
ssl_options = pika.SSLOptions(ssl_context, config['host'])
connection_parameters = pika.ConnectionParameters(
host=config['host'],
credentials=credentials,
ssl_options = ssl_options,
virtual_host = config.get('virtual_host','/'),
)
try:
connection = pika.BlockingConnection(connection_parameters)
self.mq_channel = connection.channel()
if config.get('declare',False):
self.mq_channel.exchange_declare(exchange=config['exchange'], exchange_type='topic')
except AMQPConnectionError:
continue
for retry in range(10):
if self.mq_channel is None or self.mq_channel.connection.is_closed:
credentials = pika.credentials.PlainCredentials(
username=config['username'],
password=config['password'],
)
ssl_options = None
if config.get('ssl', False):
ssl_context = ssl.create_default_context()
ssl_options = pika.SSLOptions(ssl_context, config['host'])
connection_parameters = pika.ConnectionParameters(
host=config['host'],
credentials=credentials,
ssl_options = ssl_options,
virtual_host = config.get('virtual_host','/'),
)
try:
connection = pika.BlockingConnection(connection_parameters)
self.mq_channel = connection.channel()
if config.get('declare',False):
self.mq_channel.exchange_declare(
exchange=config['exchange'],
exchange_type='topic'
)
except AMQPConnectionError:
continue

try:
self.mq_channel.basic_publish(
exchange=config['exchange'],
routing_key=routing_key,
body=json.dumps(obj),
properties=pika.BasicProperties(content_type="application/json"),
)
break
except (ConnectionClosedByBroker, StreamLostError):
self.mq_channel = None
try:
self.mq_channel.basic_publish(
exchange=config['exchange'],
routing_key=routing_key,
body=json.dumps(obj),
properties=pika.BasicProperties(content_type="application/json"),
)
break
except (ConnectionClosedByBroker, StreamLostError):
self.mq_channel = None
198 changes: 191 additions & 7 deletions mreg/tests.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import ipaddress
import json
import signal
import ssl
from datetime import timedelta
from itertools import combinations
from unittest import mock

from django.core.exceptions import ValidationError
from django.test import TestCase
from django.test import TestCase, override_settings
from django.utils import timezone

from pika import SSLOptions
from pika.exceptions import (AMQPConnectionError, ConnectionClosedByBroker,
StreamLostError)
from rest_framework.exceptions import PermissionDenied

from .models import (Cname, ForwardZone, Host, HostGroup, Ipaddress,
Loc, NameServer, Naptr,
from mreg.mqsender import MQSender

from .models import (MAX_UNUSED_LIST, Cname, ForwardZone, Host, HostGroup,
Ipaddress, Loc, NameServer, Naptr,
NetGroupRegexPermission, Network, NetworkExcludedRange,
PtrOverride, ReverseZone, Srv, Sshfp, Txt,
MAX_UNUSED_LIST)
PtrOverride, ReverseZone, Srv, Sshfp, Txt)
from .mqsender import MQSender

import signal, ipaddress

def clean_and_save(entity):
entity.full_clean()
Expand Down Expand Up @@ -1139,3 +1148,178 @@ def test_model_clean_permissions(self):
self.assertEqual(NetGroupRegexPermission.objects.first(), v6perm)
self.network_v6.delete()
self.assertEqual(NetGroupRegexPermission.objects.count(), 0)


class MQSenderTest(TestCase):
"""Test the MQSender class."""

def setUp(self):
"""Set up the MQSender tests."""
self.override_settings = override_settings(MQ_CONFIG={
'username': 'test',
'password': 'test',
'host': 'localhost',
'ssl': False,
'virtual_host': '/',
'exchange': 'test_exchange'
})
self.override_settings.enable()
self.mock_blocking_connection = mock.patch('pika.BlockingConnection')

def tearDown(self):
"""Tear down the MQSender tests."""
self.override_settings.disable()

def create_mq_sender_and_send_event(self, connection_side_effect=None):
"""Create a MQSender and send an event, set side effects if requested."""
with mock.patch('pika.BlockingConnection') as mock_blocking_connection:
if connection_side_effect is not None:
mock_blocking_connection.side_effect = connection_side_effect
mq_sender = MQSender()
mq_sender.send_event({'test': 'test'}, 'test_route')
return mq_sender

def assert_event_not_published(self, mq_sender):
"""Validate the channel is gone or the event is not published."""
if mq_sender.mq_channel is not None:
assert not mq_sender.mq_channel.basic_publish.called

def assert_event_published_correctly(self, mq_sender):
"""Validate the event published has correct data."""
assert mq_sender.mq_channel.basic_publish.called
_, kwargs = mq_sender.mq_channel.basic_publish.call_args
body = kwargs['body']
body_dict = json.loads(body)
assert 'test' in body_dict
assert 'id' in body_dict
assert 'timestamp' in body_dict
assert body_dict['test'] == 'test'

def test_send_event_default(self):
"""Test that the event is published correctly."""
mq_sender = self.create_mq_sender_and_send_event()
self.assert_event_published_correctly(mq_sender)

def test_send_event_exception(self):
"""Test that the event is published if the connection fails but retries work."""
connection_side_effects = [AMQPConnectionError] * 3 + [mock.MagicMock()]
mq_sender = self.create_mq_sender_and_send_event(connection_side_effects)
self.assert_event_published_correctly(mq_sender)

def test_send_event_failure_exception_ten_times(self):
"""Test that the event is not published if the connection fails more than 10 times."""
connection_side_effects = [AMQPConnectionError] * 11
mq_sender = self.create_mq_sender_and_send_event(connection_side_effects)
self.assert_event_not_published(mq_sender)

def test_singleton(self):
"""Test that the MQSender class is a proper singleton."""
instance1 = MQSender()
instance2 = MQSender()
self.assertEqual(instance1, instance2)

def test_initialization(self):
"""Test that the MQSender class is properly initialized."""
mq_sender = MQSender()
self.assertIsNone(mq_sender.mq_channel)
self.assertIsInstance(mq_sender.mq_id, int)

@mock.patch('pika.BlockingConnection')
def test_send_event_config(self, mock_blocking_connection):
"""Test that we handle different configuration options."""
keys = ['host', 'username', 'password', 'exchange']
full_config = {
'host': 'localhost',
'username': 'test',
'password': 'test',
'exchange': 'test_exchange'
}

# Test with MQ_CONFIG=None
with self.settings(MQ_CONFIG=None):
mq_sender = MQSender()
mq_sender.send_event({'test': 'test'}, 'test')
mock_blocking_connection.basic_publish.assert_not_called()

# Test with MQ_CONFIG not a dictionary
with self.settings(MQ_CONFIG=[]):
with self.assertRaises(ValueError):
mq_sender = MQSender()

# Test all combinations of missing keys
for r in range(1, len(keys) + 1):
for combination in combinations(keys, r):
test_config = full_config.copy()
for key in combination:
del test_config[key]
with self.settings(MQ_CONFIG=test_config):
with self.assertRaises(ValueError):
mq_sender = MQSender()

@mock.patch('ssl.SSLContext.load_cert_chain')
@mock.patch('pika.BlockingConnection')
def test_send_event_with_ssl_and_declare(
self,
mock_blocking_connection,
mock_load_cert_chain
):
"""Test that SSL and declare options are handled correctly."""
mock_connection = mock.MagicMock()
mock_channel = mock.MagicMock()
mock_connection.channel.return_value = mock_channel
mock_blocking_connection.return_value = mock_connection

ssl_context = ssl.create_default_context()
ssl_options = SSLOptions(ssl_context, 'localhost')

with override_settings(MQ_CONFIG={
'username': 'test',
'password': 'test',
'host': 'localhost',
'ssl': ssl_options,
'declare': True,
'virtual_host': '/',
'exchange': 'test_exchange'
}):
mq_sender = MQSender()
mq_sender.send_event({'test': 'test'}, 'test_route')

# Assert that exchange_declare was called if 'declare' option is set to True
mock_channel.exchange_declare.assert_called_once_with(
exchange='test_exchange',
exchange_type='topic'
)

def _run_basic_publish_test(self, side_effects, expect_published):
with mock.patch('pika.BlockingConnection') as mock_blocking_connection:
mock_connection = mock.MagicMock()
mock_channel = mock.MagicMock()
mock_channel.basic_publish.side_effect = side_effects
mock_connection.channel.return_value = mock_channel
mock_blocking_connection.return_value = mock_connection

mq_sender = MQSender()
mq_sender.send_event({'test': 'test'}, 'test_route')

if expect_published:
self.assert_event_published_correctly(mq_sender)
else:
self.assertIsNone(mq_sender.mq_channel)

def test_send_event_basic_publish_raises_exceptions(self):
"""Test that exceptions in basic_publish are handled correctly.
We specifically handle ConnectionClosedByBroker and StreamLostError."""

# Test when exceptions are raised and no event is published
self._run_basic_publish_test([
ConnectionClosedByBroker(reply_code=123, reply_text='test'),
StreamLostError] * 6,
expect_published=False)

# Test when exceptions are raised, but eventually an event is published
self._run_basic_publish_test([
ConnectionClosedByBroker(reply_code=123, reply_text='test'),
StreamLostError,
mock.MagicMock()],
expect_published=True)

0 comments on commit e8424ee

Please sign in to comment.