Archive export refactor (2) (#4534)
This PR builds on #4448,
with the goal of improving both the export writer API
(allowing for "streamed" data writing)
and performance of the export process (CPU and memory usage).

The writer is now used as a context manager,
rather than receiving all the data in one call after it has been extracted from the AiiDA database.
This means the writer is invoked throughout the export process,
and will allow less data to be kept in RAM when we move to a new archive format.
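Schematically, the flow moves from "collect everything, then write once" to "open the writer, then stream into it". A minimal sketch of the pattern (the writer protocol shown here is illustrative, not the exact writer API):

    # before: all data was extracted first, then handed to the writer in one call
    # writer.write(all_data)

    # now: the writer is entered up front and fed incrementally
    with writer as archive:                      # 'writer' and its methods are hypothetical
        archive.write_metadata(metadata)
        for entity in iter_entities_from_db():   # streamed from the database
            archive.write_entity(entity)         # written (or buffered) immediately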

The number of database queries has also been reduced, resulting in a faster process.
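The pattern is to project all required fields in a single QueryBuilder pass and iterate the results in batches, rather than issuing follow-up queries per node. A rough sketch (the projections and batch size are illustrative, not the exact ones used):

    from aiida.orm import Node, QueryBuilder

    # one projected query for all nodes, iterated in batches,
    # instead of one query per node
    query = QueryBuilder().append(Node, project=['id', 'uuid', 'node_type'])
    for pk, uuid, node_type in query.iterall(batch_size=1000):
        ...  # hand each row straight to the writer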

Lastly, the code for reading from and writing to the archive has been moved to the https://github.com/aiidateam/archive-path package.
This standardises the interface for both zip and tar and,
especially for export to tar, provides much improved performance,
since the data is now written directly to the archive
(rather than being written to a folder and only compressed at the end).
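For reference, archive-path exposes the same pathlib-like interface for both formats, so entries can be written straight into the archive; a sketch based on the package README:

    from archive_path import TarPath, ZipPath

    with TarPath('export.tar.gz', mode='w:gz') as path:
        (path / 'metadata.json').write_text('{}')  # compressed as it is written

    with ZipPath('export.zip', mode='w') as path:
        (path / 'metadata.json').write_text('{}')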

Co-authored-by: Leopold Talirz <leopold.talirz@gmail.com>
chrisjsewell and ltalirz committed Nov 12, 2020
1 parent 008580e commit bd197f3
Showing 26 changed files with 948 additions and 1,255 deletions.
29 changes: 24 additions & 5 deletions aiida/cmdline/commands/cmd_export.py
@@ -79,7 +79,9 @@ def inspect(archive, version, data, meta_data):
 @options.COMPUTERS()
 @options.GROUPS()
 @options.NODES()
-@options.ARCHIVE_FORMAT()
+@options.ARCHIVE_FORMAT(
+    type=click.Choice(['zip', 'zip-uncompressed', 'zip-lowmemory', 'tar.gz', 'null']),
+)
 @options.FORCE(help='overwrite output file if it already exists')
 @click.option(
     '-v',
@@ -101,6 +103,14 @@ def inspect(archive, version, data, meta_data):
     show_default=True,
     help='Include or exclude comments for node(s) in export. (Will also export extra users who commented).'
 )
+# will only be useful when moving to a new archive format, that does not store all data in memory
+# @click.option(
+#     '-b',
+#     '--batch-size',
+#     default=1000,
+#     type=int,
+#     help='Batch database query results in sub-collections to reduce memory usage.'
+# )
 @decorators.with_dbenv()
 def create(
     output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward,
@@ -115,6 +125,7 @@ def create(
     their provenance, according to the rules outlined in the documentation.
     You can modify some of those rules using options of this command.
     """
+    # pylint: disable=too-many-branches
     from aiida.common.log import override_log_formatter_context
     from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter
     from aiida.tools.importexport import export, ExportFileFormat, EXPORT_LOGGER
@@ -143,17 +154,22 @@ def create(
         'call_work_backward': call_work_backward,
         'include_comments': include_comments,
         'include_logs': include_logs,
-        'overwrite': force
+        'overwrite': force,
     }

     if archive_format == 'zip':
         export_format = ExportFileFormat.ZIP
-        kwargs.update({'use_compression': True})
+        kwargs.update({'writer_init': {'use_compression': True}})
     elif archive_format == 'zip-uncompressed':
         export_format = ExportFileFormat.ZIP
-        kwargs.update({'use_compression': False})
+        kwargs.update({'writer_init': {'use_compression': False}})
+    elif archive_format == 'zip-lowmemory':
+        export_format = ExportFileFormat.ZIP
+        kwargs.update({'writer_init': {'cache_zipinfo': True}})
     elif archive_format == 'tar.gz':
         export_format = ExportFileFormat.TAR_GZIPPED
+    elif archive_format == 'null':
+        export_format = 'null'

     if verbosity in ['DEBUG', 'INFO']:
         set_progress_bar_tqdm(leave=(verbosity == 'DEBUG'))
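Writer-specific options are thus collected under a single writer_init dictionary that is forwarded to the writer's constructor. A call corresponding to zip-uncompressed would look roughly like this (a sketch; the full export signature is not reproduced here):

    from aiida.orm import load_node
    from aiida.tools.importexport import export, ExportFileFormat

    entities = [load_node(1234)]  # hypothetical node to export
    export(
        entities,
        filename='export.aiida',
        file_format=ExportFileFormat.ZIP,
        writer_init={'use_compression': False},
    )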
@@ -237,7 +253,10 @@ def migrate(input_file, output_file, force, silent, in_place, archive_format, ve
     except Exception as error:  # pylint: disable=broad-except
         if verbosity == 'DEBUG':
             raise
-        echo.echo_critical(f'failed to migrate the archive file (use `--verbosity DEBUG` to see traceback): {error}')
+        echo.echo_critical(
+            'failed to migrate the archive file (use `--verbosity DEBUG` to see traceback): '
+            f'{error.__class__.__name__}:{error}'
+        )

     if verbosity in ['DEBUG', 'INFO']:
         echo.echo_success(f'migrated the archive to version {version}')
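The reworked message prefixes the exception class name, which disambiguates exceptions whose string form alone is cryptic; for instance:

    error = KeyError('node_type')
    # old style message ended in:  'node_type'
    # new style message ends in:   KeyError:'node_type'
    print(f'{error.__class__.__name__}:{error}')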
3 changes: 2 additions & 1 deletion aiida/cmdline/commands/cmd_import.py
@@ -207,7 +207,7 @@ def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migr
     archive_path = archive

     if web_based:
-        echo.echo_info(f'downloading archive {archive}')
+        echo.echo_info(f'downloading archive: {archive}')
         try:
             response = urllib.request.urlopen(archive)
         except Exception as exception:
@@ -216,6 +216,7 @@ def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migr
         archive_path = temp_folder.get_abs_path('downloaded_archive.zip')
         echo.echo_success('archive downloaded, proceeding with import')

+    echo.echo_info(f'starting import: {archive}')
     try:
         import_data(archive_path, **import_kwargs)
     except IncompatibleArchiveVersionError as exception:
233 changes: 3 additions & 230 deletions aiida/tools/importexport/archive/common.py
@@ -15,17 +15,13 @@
 from pathlib import Path
 import tarfile
 from types import TracebackType
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 import zipfile

 from aiida.common import json  # handles byte dumps
 from aiida.common.log import AIIDA_LOGGER
 from aiida.tools.importexport.common.exceptions import CorruptArchive

-__all__ = (
-    'ArchiveMetadata', 'detect_archive_type', 'null_callback', 'read_file_in_zip', 'read_file_in_tar',
-    'safe_extract_zip', 'safe_extract_tar', 'CacheFolder'
-)
+__all__ = ('ArchiveMetadata', 'detect_archive_type', 'null_callback', 'CacheFolder')

 ARCHIVE_LOGGER = AIIDA_LOGGER.getChild('archive')

@@ -47,7 +43,7 @@ class ArchiveMetadata:
     # optional data
     graph_traversal_rules: Optional[Dict[str, bool]] = dataclasses.field(default=None)
     # Entity type -> UUID list
-    entities_starting_set: Optional[Dict[str, Set[str]]] = dataclasses.field(default=None)
+    entities_starting_set: Optional[Dict[str, List[str]]] = dataclasses.field(default=None)
     include_comments: Optional[bool] = dataclasses.field(default=None)
     include_logs: Optional[bool] = dataclasses.field(default=None)
     # list of migration event notifications
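The switch from Set to List presumably reflects that this metadata ends up in the archive's JSON files, and Python's standard JSON encoder rejects sets:

    import json

    json.dumps({'Node': ['uuid-1', 'uuid-2']})    # fine: lists are JSON-serializable
    # json.dumps({'Node': {'uuid-1', 'uuid-2'}})  # TypeError: Object of type set is not JSON serializable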
@@ -80,229 +76,6 @@ def detect_archive_type(in_path: str) -> str:
    )


def read_file_in_zip(filepath: str, path: str) -> str:
    """Read a text based file from inside a zip file and return its content.

    :param filepath: the path to the zip file
    :param path: the relative path within the zip file
    """
    try:
        return zipfile.ZipFile(filepath, 'r', allowZip64=True).read(path).decode('utf8')
    except zipfile.BadZipfile as error:
        raise CorruptArchive(f'The input file cannot be read: {error}')
    except KeyError:
        raise CorruptArchive(f'required file {path} is not included')


def read_file_in_tar(filepath: str, path: str) -> str:
    """Read a text based file from inside a tar file and return its content.

    :param filepath: the path to the tar file
    :param path: the relative path within the tar file
    """
    try:
        with tarfile.open(filepath, 'r:*', format=tarfile.PAX_FORMAT) as handle:
            result = handle.extractfile(path)
            if result is None:
                raise CorruptArchive(f'required file `{path}` is not included')
            output = result.read()
            if isinstance(output, bytes):
                return output.decode('utf8')
    except tarfile.ReadError:
        raise ValueError('The input file format is not valid (not a tar file)')
    except (KeyError, AttributeError):
        raise CorruptArchive(f'required file `{path}` is not included')
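Both helpers have direct counterparts in the archive-path package, which is where reads now go; a sketch of the replacement usage (assuming the archive-path API at the time of this PR):

    from archive_path import read_file_in_tar, read_file_in_zip

    metadata = read_file_in_zip('export.zip', 'metadata.json')
    data = read_file_in_tar('export.tar.gz', 'data.json')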


def _get_filter(only_prefix: Iterable[str], ignore_prefix: Iterable[str]) -> Callable[[str], bool]:
    """Create filter for members to extract.

    :param only_prefix: Extract only internal paths starting with these prefixes
    :param ignore_prefix: Ignore internal paths starting with these prefixes
    """
    if only_prefix:

        def _filter(name):
            return any(name.startswith(prefix) for prefix in only_prefix
                       ) and all(not name.startswith(prefix) for prefix in ignore_prefix)
    else:

        def _filter(name):
            return all(not name.startswith(prefix) for prefix in ignore_prefix)

    return _filter


def safe_extract_zip(
    in_path: Union[str, Path],
    out_path: Union[str, Path],
    *,
    only_prefix: Iterable[str] = (),
    ignore_prefix: Iterable[str] = ('..', '/'),
    callback: Callable[[str, Any], None] = null_callback,
    callback_description: str = 'Extracting zip files'
):
    """Safely extract a zip file

    :param in_path: Path to extract from
    :param out_path: Path to extract to
    :param only_prefix: Extract only internal paths starting with these prefixes
    :param ignore_prefix: Ignore internal paths starting with these prefixes
    :param callback: a callback to report on the process, ``callback(action, value)``,
        with the following callback signatures:

        - ``callback('init', {'total': <int>, 'description': <str>})``,
          to signal the start of a process, its total iterations and description
        - ``callback('update', <int>)``,
          to signal an update to the process and the number of iterations to progress

    :param callback_description: the description to return in the callback
    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the file cannot be read
    """
    _filter = _get_filter(only_prefix, ignore_prefix)
    try:
        with zipfile.ZipFile(in_path, 'r', allowZip64=True) as handle:
            callback('init', {'total': 1, 'description': 'Gathering list of files to extract from zip'})
            members = [name for name in handle.namelist() if _filter(name)]
            callback('init', {'total': len(members), 'description': callback_description})
            for membername in members:
                callback('update', 1)
                handle.extract(path=os.path.abspath(out_path), member=membername)
    except zipfile.BadZipfile as error:
        raise CorruptArchive(f'The input file cannot be read: {error}')


def safe_extract_tar(
    in_path: Union[str, Path],
    out_path: Union[str, Path],
    *,
    only_prefix: Iterable[str] = (),
    ignore_prefix: Iterable[str] = ('..', '/'),
    callback: Callable[[str, Any], None] = null_callback,
    callback_description: str = 'Extracting tar files'
):
    """Safely extract a tar file

    :param in_path: Path to extract from
    :param out_path: Path to extract to
    :param only_prefix: Extract only internal paths starting with these prefixes
    :param ignore_prefix: Ignore internal paths starting with these prefixes
    :param callback: a callback to report on the process, ``callback(action, value)``,
        with the following callback signatures:

        - ``callback('init', {'total': <int>, 'description': <str>})``,
          to signal the start of a process, its total iterations and description
        - ``callback('update', <int>)``,
          to signal an update to the process and the number of iterations to progress

    :param callback_description: the description to return in the callback
    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the file cannot be read
    """
    _filter = _get_filter(only_prefix, ignore_prefix)
    try:
        with tarfile.open(in_path, 'r:*', format=tarfile.PAX_FORMAT) as handle:
            callback('init', {'total': 1, 'description': 'Computing tar objects to extract'})
            members = [m for m in handle.getmembers() if _filter(m.name)]
            callback('init', {'total': len(members), 'description': callback_description})
            for member in members:
                callback('update', 1)
                if member.isdev():
                    # safety: skip if character device, block device or FIFO
                    msg = f'WARNING, device found inside the tar file: {member.name}'
                    ARCHIVE_LOGGER.warning(msg)
                    continue
                if member.issym() or member.islnk():
                    # safety: skip symlinks
                    msg = f'WARNING, symlink found inside the tar file: {member.name}'
                    ARCHIVE_LOGGER.warning(msg)
                    continue
                handle.extract(path=os.path.abspath(out_path), member=member)
    except tarfile.ReadError as error:
        raise CorruptArchive(f'The input file cannot be read: {error}')
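Extraction likewise moves behind archive-path's path objects; a sketch (the extract_tree method name is assumed from the package, and may differ):

    from archive_path import ZipPath

    # extract all members of the archive to a destination folder
    ZipPath('export.zip').extract_tree('/tmp/extracted')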


def compress_folder_zip(
    in_path: Union[str, Path],
    out_path: Union[str, Path],
    *,
    compression: int = zipfile.ZIP_DEFLATED,
    callback: Callable[[str, Any], None] = null_callback,
    callback_description: str = 'Compressing objects as zip'
):
    """Compress a folder as a zip file

    :param in_path: Path to compress
    :param out_path: Path to compress to
    :param compression: the compression type (see zipfile module)
    :param callback: a callback to report on the process, ``callback(action, value)``,
        with the following callback signatures:

        - ``callback('init', {'total': <int>, 'description': <str>})``,
          to signal the start of a process, its total iterations and description
        - ``callback('update', <int>)``,
          to signal an update to the process and the number of iterations to progress

    :param callback_description: the description to return in the callback
    """
    callback('init', {'total': 1, 'description': 'Computing objects to compress'})
    count = 0
    for _, dirnames, filenames in os.walk(in_path):
        count += len(dirnames) + len(filenames)
    callback('init', {'total': count, 'description': callback_description})
    with zipfile.ZipFile(out_path, mode='w', compression=compression, allowZip64=True) as archive:
        for dirpath, dirnames, filenames in os.walk(in_path):
            relpath = os.path.relpath(dirpath, in_path)
            for filename in dirnames + filenames:
                callback('update', 1)
                real_src = os.path.join(dirpath, filename)
                real_dest = os.path.join(relpath, filename)
                archive.write(real_src, real_dest)


def compress_folder_tar(
    in_path: Union[str, Path],
    out_path: Union[str, Path],
    *,
    callback: Callable[[str, Any], None] = null_callback,
    callback_description: str = 'Compressing objects as tar'
):
    """Compress a folder as a tar file

    :param in_path: Path to compress
    :param out_path: Path to compress to
    :param callback: a callback to report on the process, ``callback(action, value)``,
        with the following callback signatures:

        - ``callback('init', {'total': <int>, 'description': <str>})``,
          to signal the start of a process, its total iterations and description
        - ``callback('update', <int>)``,
          to signal an update to the process and the number of iterations to progress

    :param callback_description: the description to return in the callback
    """
    callback('init', {'total': 1, 'description': 'Computing objects to compress'})
    count = 0
    for _, dirnames, filenames in os.walk(in_path):
        count += len(dirnames) + len(filenames)
    callback('init', {'total': count + 1, 'description': callback_description})

    def _filter(tarinfo):
        callback('update', 1)
        return tarinfo

    with tarfile.open(os.path.abspath(out_path), 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as archive:
        archive.add(os.path.abspath(in_path), arcname='', filter=_filter)
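These two compress-a-staging-folder helpers are exactly what direct-to-archive writing makes redundant; with archive-path, entries stream into the (compressed) file as they are produced, e.g.:

    from archive_path import TarPath

    with TarPath('export.tar.gz', mode='w:gz') as archive:
        for name, content in produce_entries():   # hypothetical producer
            (archive / name).write_text(content)  # no staging folder, no final compression pass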


class CacheFolder:
    """A class to encapsulate a folder path with cached read/writes.
4 changes: 2 additions & 2 deletions aiida/tools/importexport/archive/migrations/v03_to_v04.py
@@ -421,8 +421,8 @@ def add_extras(data):
     Since Extras were not available previously and usually only include hashes,
     the Node ids will be added, but included as empty dicts
     """
-    node_extras = {}
-    node_extras_conversion = {}
+    node_extras: dict = {}
+    node_extras_conversion: dict = {}

     for node_id in data['export_data'].get('Node', {}):
         node_extras[node_id] = {}
4 changes: 4 additions & 0 deletions aiida/tools/importexport/archive/migrations/v05_to_v06.py
@@ -24,6 +24,8 @@
 Where id is a SQLA id and migration-name is the name of the particular migration.
 """
 # pylint: disable=invalid-name
+from typing import Union
+
 from aiida.tools.importexport.archive.common import CacheFolder

 from .utils import verify_metadata_version, update_metadata
@@ -33,6 +35,8 @@ def migrate_deserialized_datetime(data, conversion):
"""Deserialize datetime strings from export archives, meaning to reattach the UTC timezone information."""
from aiida.tools.importexport.common.exceptions import ArchiveMigrationError

ret_data: Union[str, dict, list]

if isinstance(data, dict):
ret_data = {}
for key, value in data.items():
