Skip to content

Commit

Permalink
[MRG+2] Back out threading.local in favor of global stepwise context (#…
Browse files Browse the repository at this point in the history
…275)

[HOTFIX] 🔥 Back out threading.local management of context store

Use a global singleton to manage context store rather than a threading
local, since #271 seems to indicate a problem in sharing data between
threads in job schedulers.
  • Loading branch information
tgsmith61591 authored Dec 17, 2019
1 parent a7b1a3e commit 0a4ec38
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 13 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ v0.8.1) will document the latest features.

* Fix bug where 1.5.1 documentation was labeled version "0.0.0".

* Fix bug reported in `#271 <https://github.com/alkaline-ml/pmdarima/issues/271>`_, where
the use of ``threading.local`` to store stepwise context information may have broken
job schedulers.

* Fix bug reported in `#272 <https://github.com/alkaline-ml/pmdarima/issues/272>`_, where
the new default value of ``max_order`` can cause a ``ValueError`` even in default cases
when ``stepwise=False``.
Expand Down
27 changes: 15 additions & 12 deletions pmdarima/arima/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,24 @@
# Author: Krishna Sunkara (kpsunkara)
#
# Re-entrant, reusable context manager to store execution context. Introduced
# in pmdarima 1.5.0 (see #221)
# in pmdarima 1.5.0 (see #221), redesigned not to use thread locals in #273
# (see #275 for context).

import threading
from abc import ABC, abstractmethod
from enum import Enum
import collections

# thread local value to store the context info
_ctx = threading.local()
_ctx.store = {}

__all__ = ['AbstractContext', 'ContextStore', 'ContextType']


class _CtxSingleton:
"""Singleton class to store context information"""
store = {}


_ctx = _CtxSingleton()


class ContextType(Enum):
"""Context Type Enumeration
Expand All @@ -30,9 +34,8 @@ class AbstractContext(ABC):
"""An abstract context manager to store execution context.
A generic, re-entrant, reusable context manager to store
execution context in a threading.local instance. Has helper
methods to iterate over the context info and provide a
string representation of the context info.
execution context. Has helper methods to iterate over the context info
and provide a string representation of the context info.
"""
def __init__(self, **kwargs):
# remove None valued entries,
Expand Down Expand Up @@ -93,10 +96,10 @@ def get_type(self):


class ContextStore:
"""A class to wrap access to threading.local() context store
"""A class to wrap access to the global context store
This class hosts static methods to wrap access to and
encapsulate the threading.local() instance
This class hosts static methods to wrap access to and encapsulate the
singleton content store instance
"""
@staticmethod
def get_context(context_type):
Expand Down
32 changes: 32 additions & 0 deletions pmdarima/arima/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from pmdarima.arima.auto import StepwiseContext, auto_arima
from pmdarima.arima._context import ContextStore, ContextType
from pmdarima.arima import _context as context_lib
from pmdarima.datasets import load_lynx, load_wineind
from unittest import mock
import threading
import collections
import pytest

lynx = load_lynx()
Expand Down Expand Up @@ -105,3 +109,31 @@ def test_add_get_remove_context_args():

with pytest.raises(ValueError):
ContextStore.get_context(None)


def test_context_store_accessible_across_threads():
# Make sure it's completely empty by patching it
d = {}
with mock.patch('pmdarima.arima._context._ctx.store', d):

# pushes onto the Context Store
def push(n):
# n is the number of times this has been executed before. If > 0,
# assert there is a context there
if n > 0:
assert len(context_lib._ctx.store[ContextType.STEPWISE]) == n
else:
context_lib._ctx.store[ContextType.STEPWISE] = \
collections.deque()

new_ctx = StepwiseContext()
context_lib._ctx.store[ContextType.STEPWISE].append(new_ctx)
assert len(context_lib._ctx.store[ContextType.STEPWISE]) == n + 1

for i in range(5):
t = threading.Thread(target=push, args=(i,))
t.start()
t.join(1) # it shouldn't take even close to this time

# Assert the mock has lifted
assert context_lib._ctx.store is not d
2 changes: 1 addition & 1 deletion pmdarima/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def tsdisplay(y, lag_max=50, figsize=(8, 6), title=None, bins=25,

# ax2 is simply the histogram
hist_kwargs = {} if not hist_kwargs else hist_kwargs
_ = ax2.hist(y, bins=bins, **hist_kwargs)
_ = ax2.hist(y, bins=bins, **hist_kwargs) # noqa
ax2.set_title("Frequency")

fig.tight_layout()
Expand Down

0 comments on commit 0a4ec38

Please sign in to comment.