From 8a497a9a7c0f5614fbd1f022c9c283c1cd9c202b Mon Sep 17 00:00:00 2001
From: Bach Dániel <bd@ik.bme.hu>
Date: Mon, 29 Sep 2014 21:33:49 +0200
Subject: [PATCH] firewall: fix race condition in reloadtask_worker

Closes #311
---
 circle/firewall/tasks/local_tasks.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/circle/firewall/tasks/local_tasks.py b/circle/firewall/tasks/local_tasks.py
index 3085aca..5731a82 100644
--- a/circle/firewall/tasks/local_tasks.py
+++ b/circle/firewall/tasks/local_tasks.py
@@ -29,26 +29,24 @@ settings = django.conf.settings.FIREWALL_SETTINGS
 logger = getLogger(__name__)
 
 
-def _apply_once(name, queues, task, data):
+def _apply_once(name, tasks, queues, task, data):
     """Reload given networking component if needed.
     """
 
-    lockname = "%s_lock" % name
-    if not cache.get(lockname):
+    if name not in tasks:
         return
-    cache.delete(lockname)
 
     data = data()
     for queue in queues:
         try:
-            task.apply_async(args=data, queue=queue, expires=60).get(timeout=5)
+            task.apply_async(args=data, queue=queue, expires=60).get(timeout=2)
             logger.info("%s configuration is reloaded. (queue: %s)",
                         name, queue)
         except TimeoutError as e:
-            logger.critical('%s (queue: %s)', e, queue)
+            logger.critical('%s (queue: %s, task: %s)', e, queue, name)
         except:
-            logger.critical('Unhandled exception: queue: %s data: %s',
-                            queue, data, exc_info=True)
+            logger.critical('Unhandled exception: queue: %s data: %s task: %s',
+                            queue, data, name, exc_info=True)
 
 
 def get_firewall_queues():
@@ -68,19 +66,28 @@ def reloadtask_worker():
     from remote_tasks import (reload_dns, reload_dhcp, reload_firewall,
                               reload_firewall_vlan, reload_blacklist)
 
+    tasks = []
+    for i in ('dns', 'dhcp', 'firewall', 'firewall_vlan', 'blacklist'):
+        lockname = "%s_lock" % i
+        if cache.get(lockname):
+            tasks.append(i)
+        cache.delete(lockname)
+
+    logger.info("reloadtask_worker: Reload %s", ", ".join(tasks))
+
     firewall_queues = get_firewall_queues()
     dns_queues = [("%s.dns" % i) for i in
                   settings.get('dns_queues', [gethostname()])]
 
-    _apply_once('dns', dns_queues, reload_dns,
+    _apply_once('dns', tasks, dns_queues, reload_dns,
                 lambda: (dns(), ))
-    _apply_once('dhcp', firewall_queues, reload_dhcp,
+    _apply_once('dhcp', tasks, firewall_queues, reload_dhcp,
                 lambda: (dhcp(), ))
-    _apply_once('firewall', firewall_queues, reload_firewall,
+    _apply_once('firewall', tasks, firewall_queues, reload_firewall,
                 lambda: (BuildFirewall().build_ipt()))
-    _apply_once('firewall_vlan', firewall_queues, reload_firewall_vlan,
+    _apply_once('firewall_vlan', tasks, firewall_queues, reload_firewall_vlan,
                 lambda: (vlan(), ))
-    _apply_once('blacklist', firewall_queues, reload_blacklist,
+    _apply_once('blacklist', tasks, firewall_queues, reload_blacklist,
                 lambda: (list(ipset()), ))
 
 
--
libgit2 0.26.0