ARROW-2657: [Python] Import TensorFlow python extension before pyarrow to avoid segfault #2210

Closed
wants to merge 11 commits into from
18 changes: 18 additions & 0 deletions ci/travis_script_manylinux.sh
@@ -24,3 +24,21 @@ pushd python/manylinux1
git clone ../../ arrow
docker build -t arrow-base-x86_64 -f Dockerfile-x86_64 .
docker run --shm-size=2g --rm -e PYARROW_PARALLEL=3 -v $PWD:/io arrow-base-x86_64 /io/build_arrow.sh

# Testing for https://issues.apache.org/jira/browse/ARROW-2657
# These tests cannot be run inside of the docker container, since TensorFlow
# does not run on manylinux1
Contributor

Did this test fail before this fix?

Contributor Author

Yeah, I tried it on an EC2 instance before putting it in.


source $TRAVIS_BUILD_DIR/ci/travis_env_common.sh

source $TRAVIS_BUILD_DIR/ci/travis_install_conda.sh

PYTHON_VERSION=3.6
CONDA_ENV_DIR=$TRAVIS_BUILD_DIR/pyarrow-test-$PYTHON_VERSION

conda create -y -q -p $CONDA_ENV_DIR python=$PYTHON_VERSION
source activate $CONDA_ENV_DIR

pip install -q tensorflow
pip install "dist/`ls dist/ | grep cp36`"
python -c "import pyarrow; import tensorflow"
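
The one-line import check above can also be run from Python in a child interpreter, so that a crash shows up as a return code instead of taking down the calling process. The following is a minimal sketch, not part of this patch; the helper name `imports_cleanly` and its `timeout` parameter are hypothetical.

```python
import subprocess
import sys


def imports_cleanly(statement, timeout=120):
    """Run `statement` in a fresh interpreter; return (ok, returncode).

    Hypothetical helper for illustration only. A fresh process is used
    so that a segfault during import cannot kill the caller.
    """
    result = subprocess.run(
        [sys.executable, "-c", statement],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        timeout=timeout,
    )
    # On POSIX, a negative return code -N means the child was terminated
    # by signal N (SIGSEGV is signal 11 on Linux).
    return result.returncode == 0, result.returncode
```

Calling `imports_cleanly("import pyarrow; import tensorflow")` in the test environment would then distinguish a clean import (return code 0) from an ImportError (positive code) or a crash (negative code on POSIX).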
7 changes: 7 additions & 0 deletions python/pyarrow/__init__.py
@@ -44,6 +44,13 @@ def parse_version(root):
__version__ = None


import pyarrow.compat as compat


# Workaround for https://issues.apache.org/jira/browse/ARROW-2657
compat.import_tensorflow_extension()
Member

Shouldn't this be run only on Linux?

Member

Yes, sorry I missed that (it was in the prior iteration of the patch but I didn't look closely enough at this function): https://issues.apache.org/jira/browse/ARROW-2795
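
For illustration, a Linux-only guard of the kind asked about here could look like the sketch below. This is an assumption about one possible shape, not the change made under ARROW-2795, which may differ; the function name `should_preload_tensorflow` is hypothetical.

```python
import sys


def should_preload_tensorflow(platform=None):
    """Return True only when running on Linux.

    Hypothetical helper for illustration; `platform` defaults to
    sys.platform and is a parameter only so the check is easy to test.
    """
    if platform is None:
        platform = sys.platform
    # sys.platform is "linux" on Python 3.3+ and "linux2" on Python 2,
    # so a prefix check covers both.
    return platform.startswith("linux")
```

With such a helper, the call in `__init__.py` would be wrapped as `if should_preload_tensorflow(): compat.import_tensorflow_extension()`, so macOS and Windows skip the workaround entirely.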



from pyarrow.lib import cpu_count, set_cpu_count
from pyarrow.lib import (null, bool_,
                         int8, int16, int32, int64,
42 changes: 42 additions & 0 deletions python/pyarrow/compat.py
@@ -160,6 +160,48 @@ def encode_file_path(path):
    # will convert utf8 to utf16
    return encoded_path

def import_tensorflow_extension():
    """
    Load the TensorFlow extension if it exists.

    This is used to load the TensorFlow extension before
    pyarrow.lib. If we don't do this there are symbol clashes
    between TensorFlow's use of threading and our global
    thread pool, see also
    https://issues.apache.org/jira/browse/ARROW-2657 and
    https://github.com/apache/arrow/pull/2096.
    """
    import os
    import site
    tensorflow_loaded = False

    # Try to load the tensorflow extension directly
    # This is a performance optimization, tensorflow will always be
    # loaded via the "import tensorflow" statement below if this
    # doesn't succeed.
    try:
        site_paths = site.getsitepackages() + [site.getusersitepackages()]
    except AttributeError:
        # Workaround for https://github.com/pypa/virtualenv/issues/228,
        # this happens in some configurations of virtualenv
        site_paths = [os.path.dirname(site.__file__) + '/site-packages']
    for site_path in site_paths:
        ext = os.path.join(site_path, "tensorflow",
                           "libtensorflow_framework.so")
        if os.path.exists(ext):
            import ctypes
            ctypes.CDLL(ext)
            tensorflow_loaded = True
            break

    # If the above failed, try to load tensorflow the normal way
    # (this is more expensive)
    if not tensorflow_loaded:
        try:
            import tensorflow
        except ImportError:
            pass


integer_types = six.integer_types + (np.integer,)
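
The path search inside `import_tensorflow_extension` can be read as a small standalone step: given a list of site directories, return the first one that contains the shared library. The sketch below isolates that step so it can be exercised without TensorFlow installed. It is not code from this patch; the name `find_extension` and its `package` and `library` parameters are hypothetical, and a temporary directory stands in for site-packages.

```python
import os
import tempfile


def find_extension(site_paths, package="tensorflow",
                   library="libtensorflow_framework.so"):
    """Return the first existing <site_path>/<package>/<library>, or None.

    Hypothetical helper for illustration; it only locates the file and
    does not load it.
    """
    for site_path in site_paths:
        candidate = os.path.join(site_path, package, library)
        if os.path.exists(candidate):
            return candidate
    return None


# Demonstration: a throwaway directory plays the role of site-packages,
# and an empty file plays the role of the shared library.
with tempfile.TemporaryDirectory() as fake_site:
    os.makedirs(os.path.join(fake_site, "tensorflow"))
    lib = os.path.join(fake_site, "tensorflow", "libtensorflow_framework.so")
    open(lib, "w").close()
    found = find_extension(["/nonexistent", fake_site])
```

Keeping the lookup separate from the `ctypes.CDLL` call means the search order can be tested on any platform, while the actual load stays a single line at the call site.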
