diff --git a/aiida/cmdline/commands/cmd_export.py b/aiida/cmdline/commands/cmd_export.py
index b282d8a78a..add1f9641e 100644
--- a/aiida/cmdline/commands/cmd_export.py
+++ b/aiida/cmdline/commands/cmd_export.py
@@ -79,7 +79,9 @@ def inspect(archive, version, data, meta_data):
 @options.COMPUTERS()
 @options.GROUPS()
 @options.NODES()
-@options.ARCHIVE_FORMAT()
+@options.ARCHIVE_FORMAT(
+    type=click.Choice(['zip', 'zip-uncompressed', 'zip-lowmemory', 'tar.gz', 'null']),
+)
 @options.FORCE(help='overwrite output file if it already exists')
 @click.option(
     '-v',
@@ -101,6 +103,14 @@ def inspect(archive, version, data, meta_data):
     show_default=True,
     help='Include or exclude comments for node(s) in export. (Will also export extra users who commented).'
 )
+# will only be useful when moving to a new archive format, that does not store all data in memory
+# @click.option(
+#     '-b',
+#     '--batch-size',
+#     default=1000,
+#     type=int,
+#     help='Batch database query results in sub-collections to reduce memory usage.'
+# )
 @decorators.with_dbenv()
 def create(
     output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward,
@@ -115,6 +125,7 @@ def create(
     their provenance, according to the rules outlined in the documentation. You can modify some of those rules
     using options of this command.
     """
+    # pylint: disable=too-many-branches
     from aiida.common.log import override_log_formatter_context
     from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter
     from aiida.tools.importexport import export, ExportFileFormat, EXPORT_LOGGER
@@ -143,17 +154,22 @@ def create(
         'call_work_backward': call_work_backward,
         'include_comments': include_comments,
         'include_logs': include_logs,
-        'overwrite': force
+        'overwrite': force,
     }
 
     if archive_format == 'zip':
         export_format = ExportFileFormat.ZIP
-        kwargs.update({'use_compression': True})
+        kwargs.update({'writer_init': {'use_compression': True}})
     elif archive_format == 'zip-uncompressed':
         export_format = ExportFileFormat.ZIP
-        kwargs.update({'use_compression': False})
+        kwargs.update({'writer_init': {'use_compression': False}})
+    elif archive_format == 'zip-lowmemory':
+        export_format = ExportFileFormat.ZIP
+        kwargs.update({'writer_init': {'cache_zipinfo': True}})
     elif archive_format == 'tar.gz':
         export_format = ExportFileFormat.TAR_GZIPPED
+    elif archive_format == 'null':
+        export_format = 'null'
 
     if verbosity in ['DEBUG', 'INFO']:
         set_progress_bar_tqdm(leave=(verbosity == 'DEBUG'))
@@ -237,7 +253,10 @@ def migrate(input_file, output_file, force, silent, in_place, archive_format, ve
     except Exception as error:  # pylint: disable=broad-except
         if verbosity == 'DEBUG':
             raise
-        echo.echo_critical(f'failed to migrate the archive file (use `--verbosity DEBUG` to see traceback): {error}')
+        echo.echo_critical(
+            'failed to migrate the archive file (use `--verbosity DEBUG` to see traceback): '
+            f'{error.__class__.__name__}:{error}'
+        )
 
     if verbosity in ['DEBUG', 'INFO']:
         echo.echo_success(f'migrated the archive to version {version}')
diff --git a/aiida/cmdline/commands/cmd_import.py b/aiida/cmdline/commands/cmd_import.py
index 70a88ba26c..251006375e 100644
--- a/aiida/cmdline/commands/cmd_import.py
+++ b/aiida/cmdline/commands/cmd_import.py
@@ -207,7 +207,7 @@ def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migr
     archive_path = archive
 
     if web_based:
-        echo.echo_info(f'downloading archive {archive}')
+        echo.echo_info(f'downloading archive: {archive}')
         try:
             response = urllib.request.urlopen(archive)
         except Exception as exception:
@@ -216,6 +216,7 @@
         archive_path = temp_folder.get_abs_path('downloaded_archive.zip')
         echo.echo_success('archive downloaded, proceeding with import')
 
+    echo.echo_info(f'starting import: {archive}')
     try:
         import_data(archive_path, **import_kwargs)
     except IncompatibleArchiveVersionError as exception:
diff --git a/aiida/tools/importexport/archive/common.py b/aiida/tools/importexport/archive/common.py
index 83e32c728b..583115fbf5 100644
--- a/aiida/tools/importexport/archive/common.py
+++ b/aiida/tools/importexport/archive/common.py
@@ -15,17 +15,13 @@ from pathlib import Path
 import tarfile
 from types import TracebackType
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 import zipfile
 
 from aiida.common import json  # handles byte dumps
 from aiida.common.log import AIIDA_LOGGER
-from aiida.tools.importexport.common.exceptions import CorruptArchive
 
-__all__ = (
-    'ArchiveMetadata', 'detect_archive_type', 'null_callback', 'read_file_in_zip', 'read_file_in_tar',
-    'safe_extract_zip', 'safe_extract_tar', 'CacheFolder'
-)
+__all__ = ('ArchiveMetadata', 'detect_archive_type', 'null_callback', 'CacheFolder')
 
 ARCHIVE_LOGGER = AIIDA_LOGGER.getChild('archive')
@@ -47,7 +43,7 @@ class ArchiveMetadata:
     # optional data
     graph_traversal_rules: Optional[Dict[str, bool]] = dataclasses.field(default=None)
     # Entity type -> UUID list
-    entities_starting_set: Optional[Dict[str, Set[str]]] = dataclasses.field(default=None)
+    entities_starting_set: Optional[Dict[str, List[str]]] = dataclasses.field(default=None)
     include_comments: Optional[bool] = dataclasses.field(default=None)
     include_logs: Optional[bool] = dataclasses.field(default=None)
     # list of migration event notifications
@@ -80,229 +76,6 @@ def detect_archive_type(in_path: str) -> str:
     )
 
 
-def read_file_in_zip(filepath: str, path: str) -> str:
-    """Read a text based file from inside a zip file and return its content.
-
-    :param filepath: the path to the zip file
-    :param path: the relative path within the zip file
-
-    """
-    try:
-        return zipfile.ZipFile(filepath, 'r', allowZip64=True).read(path).decode('utf8')
-    except zipfile.BadZipfile as error:
-        raise CorruptArchive(f'The input file cannot be read: {error}')
-    except KeyError:
-        raise CorruptArchive(f'required file {path} is not included')
-
-
-def read_file_in_tar(filepath: str, path: str) -> str:
-    """Read a text based file from inside a tar file and return its content.
-
-    :param filepath: the path to the tar file
-    :param path: the relative path within the tar file
-
-    """
-    try:
-        with tarfile.open(filepath, 'r:*', format=tarfile.PAX_FORMAT) as handle:
-            result = handle.extractfile(path)
-            if result is None:
-                raise CorruptArchive(f'required file `{path}` is not included')
-            output = result.read()
-            if isinstance(output, bytes):
-                return output.decode('utf8')
-    except tarfile.ReadError:
-        raise ValueError('The input file format is not valid (not a tar file)')
-    except (KeyError, AttributeError):
-        raise CorruptArchive(f'required file `{path}` is not included')
-
-
-def _get_filter(only_prefix: Iterable[str], ignore_prefix: Iterable[str]) -> Callable[[str], bool]:
-    """Create filter for members to extract.
-
-    :param only_prefix: Extract only internal paths starting with these prefixes
-    :param ignore_prefix: Ignore internal paths starting with these prefixes
-
-    """
-    if only_prefix:
-
-        def _filter(name):
-            return any(name.startswith(prefix) for prefix in only_prefix
-                       ) and all(not name.startswith(prefix) for prefix in ignore_prefix)
-    else:
-
-        def _filter(name):
-            return all(not name.startswith(prefix) for prefix in ignore_prefix)
-
-    return _filter
-
-
-def safe_extract_zip(
-    in_path: Union[str, Path],
-    out_path: Union[str, Path],
-    *,
-    only_prefix: Iterable[str] = (),
-    ignore_prefix: Iterable[str] = ('..', '/'),
-    callback: Callable[[str, Any], None] = null_callback,
-    callback_description: str = 'Extracting zip files'
-):
-    """Safely extract a zip file
-
-    :param in_path: Path to extract from
-    :param out_path: Path to extract to
-    :param only_prefix: Extract only internal paths starting with these prefixes
-    :param ignore_prefix: Ignore internal paths starting with these prefixes
-    :param callback: a callback to report on the process, ``callback(action, value)``,
-        with the following callback signatures:
-
-        - ``callback('init', {'total': <int>, 'description': <str>})``,
-            to signal the start of a process, its total iterations and description
-        - ``callback('update', <int>)``,
-            to signal an update to the process and the number of iterations to progress
-
-    :param callback_description: the description to return in the callback
-
-    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the file cannot be read
-
-    """
-    _filter = _get_filter(only_prefix, ignore_prefix)
-    try:
-        with zipfile.ZipFile(in_path, 'r', allowZip64=True) as handle:
-            callback('init', {'total': 1, 'description': 'Gathering list of files to extract from zip'})
-            members = [name for name in handle.namelist() if _filter(name)]
-            callback('init', {'total': len(members), 'description': callback_description})
-            for membername in members:
-                callback('update', 1)
-                handle.extract(path=os.path.abspath(out_path), member=membername)
-    except zipfile.BadZipfile as error:
-        raise CorruptArchive(f'The input file cannot be read: {error}')
-
-
-def safe_extract_tar(
-    in_path: Union[str, Path],
-    out_path: Union[str, Path],
-    *,
-    only_prefix: Iterable[str] = (),
-    ignore_prefix: Iterable[str] = ('..', '/'),
-    callback: Callable[[str, Any], None] = null_callback,
-    callback_description: str = 'Extracting tar files'
-):
-    """Safely extract a tar file
-
-    :param in_path: Path to extract from
-    :param out_path: Path to extract to
-    :param only_prefix: Extract only internal paths starting with these prefixes
-    :param ignore_prefix: Ignore internal paths starting with these prefixes
-    :param callback: a callback to report on the process, ``callback(action, value)``,
-        with the following callback signatures:
-
-        - ``callback('init', {'total': <int>, 'description': <str>})``,
-            to signal the start of a process, its total iterations and description
-        - ``callback('update', <int>)``,
-            to signal an update to the process and the number of iterations to progress
-
-    :param callback_description: the description to return in the callback
-
-    :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the file cannot be read
-
-    """
-    _filter = _get_filter(only_prefix, ignore_prefix)
-    try:
-        with tarfile.open(in_path, 'r:*', format=tarfile.PAX_FORMAT) as handle:
-            callback('init', {'total': 1, 'description': 'Computing tar objects to extract'})
-            members = [m for m in handle.getmembers() if _filter(m.name)]
-            callback('init', {'total': len(members), 'description': callback_description})
-            for member in members:
-                callback('update', 1)
-                if member.isdev():
-                    # safety: skip if character device, block device or FIFO
-                    msg = f'WARNING, device found inside the tar file: {member.name}'
-                    ARCHIVE_LOGGER.warning(msg)
-                    continue
-                if member.issym() or member.islnk():
-                    # safety: skip symlinks
-                    msg = f'WARNING, symlink found inside the tar file: {member.name}'
-                    ARCHIVE_LOGGER.warning(msg)
-                    continue
-                handle.extract(path=os.path.abspath(out_path), member=member)
-    except tarfile.ReadError as error:
-        raise CorruptArchive(f'The input file cannot be read: {error}')
-
-
-def compress_folder_zip(
-    in_path: Union[str, Path],
-    out_path: Union[str, Path],
-    *,
-    compression: int = zipfile.ZIP_DEFLATED,
-    callback: Callable[[str, Any], None] = null_callback,
-    callback_description: str = 'Compressing objects as zip'
-):
-    """Compress a folder as a zip file
-
-    :param in_path: Path to compress
-    :param out_path: Path to compress to
-    :param compression: the compression type (see zipfile module)
-    :param callback: a callback to report on the process, ``callback(action, value)``,
-        with the following callback signatures:
-
-        - ``callback('init', {'total': <int>, 'description': <str>})``,
-            to signal the start of a process, its total iterations and description
-        - ``callback('update', <int>)``,
-            to signal an update to the process and the number of iterations to progress
-
-    :param callback_description: the description to return in the callback
-
-    """
-    callback('init', {'total': 1, 'description': 'Computing objects to compress'})
-    count = 0
-    for _, dirnames, filenames in os.walk(in_path):
-        count += len(dirnames) + len(filenames)
-    callback('init', {'total': count, 'description': callback_description})
-    with zipfile.ZipFile(out_path, mode='w', compression=compression, allowZip64=True) as archive:
-        for dirpath, dirnames, filenames in os.walk(in_path):
-            relpath = os.path.relpath(dirpath, in_path)
-            for filename in dirnames + filenames:
-                callback('update', 1)
-                real_src = os.path.join(dirpath, filename)
-                real_dest = os.path.join(relpath, filename)
-                archive.write(real_src, real_dest)
-
-
-def compress_folder_tar(
-    in_path: Union[str, Path],
-    out_path: Union[str, Path],
-    *,
-    callback: Callable[[str, Any], None] = null_callback,
-    callback_description: str = 'Compressing objects as tar'
-):
-    """Compress a folder as a zip file
-
-    :param in_path: Path to compress
-    :param out_path: Path to compress to
-    :param callback: a callback to report on the process, ``callback(action, value)``,
-        with the following callback signatures:
-
-        - ``callback('init', {'total': <int>, 'description': <str>})``,
-            to signal the start of a process, its total iterations and description
-        - ``callback('update', <int>)``,
-            to signal an update to the process and the number of iterations to progress
-
-    :param callback_description: the description to return in the callback
-
-    """
-    callback('init', {'total': 1, 'description': 'Computing objects to compress'})
-    count = 0
-    for _, dirnames, filenames in os.walk(in_path):
-        count += len(dirnames) + len(filenames)
-    callback('init', {'total': count + 1, 'description': callback_description})
-
-    def _filter(tarinfo):
-        callback('update', 1)
-        return tarinfo
-
-    with tarfile.open(os.path.abspath(out_path), 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as archive:
-        archive.add(os.path.abspath(in_path), arcname='', filter=_filter)
-
-
 class CacheFolder:
     """A class to encapsulate a folder path with cached read/writes.
diff --git a/aiida/tools/importexport/archive/migrations/v03_to_v04.py b/aiida/tools/importexport/archive/migrations/v03_to_v04.py
index 55a3577484..b0c3fc97df 100644
--- a/aiida/tools/importexport/archive/migrations/v03_to_v04.py
+++ b/aiida/tools/importexport/archive/migrations/v03_to_v04.py
@@ -421,8 +421,8 @@ def add_extras(data):
     Since Extras were not available previously and usually only include hashes,
     the Node ids will be added, but included as empty dicts
     """
-    node_extras = {}
-    node_extras_conversion = {}
+    node_extras: dict = {}
+    node_extras_conversion: dict = {}
 
     for node_id in data['export_data'].get('Node', {}):
         node_extras[node_id] = {}
diff --git a/aiida/tools/importexport/archive/migrations/v05_to_v06.py b/aiida/tools/importexport/archive/migrations/v05_to_v06.py
index 8f2ff472fe..b0be661591 100644
--- a/aiida/tools/importexport/archive/migrations/v05_to_v06.py
+++ b/aiida/tools/importexport/archive/migrations/v05_to_v06.py
@@ -24,6 +24,8 @@ Where id is a SQLA id and migration-name is the name of the particular
 migration.
 """
 # pylint: disable=invalid-name
+from typing import Union
+
 from aiida.tools.importexport.archive.common import CacheFolder
 
 from .utils import verify_metadata_version, update_metadata
@@ -33,6 +35,8 @@ def migrate_deserialized_datetime(data, conversion):
     """Deserialize datetime strings from export archives, meaning to reattach the UTC timezone information."""
     from aiida.tools.importexport.common.exceptions import ArchiveMigrationError
 
+    ret_data: Union[str, dict, list]
+
     if isinstance(data, dict):
         ret_data = {}
         for key, value in data.items():
diff --git a/aiida/tools/importexport/archive/migrators.py b/aiida/tools/importexport/archive/migrators.py
index 15cc1def3d..1446cc7ef2 100644
--- a/aiida/tools/importexport/archive/migrators.py
+++ b/aiida/tools/importexport/archive/migrators.py
@@ -13,18 +13,18 @@ import os
 from pathlib import Path
 import shutil
+import tarfile
 import tempfile
 from typing import Any, Callable, cast, List, Optional, Type, Union
 import zipfile
 
+from archive_path import TarPath, ZipPath, read_file_in_tar, read_file_in_zip
+
 from aiida.common.log import AIIDA_LOGGER
 from aiida.common.progress_reporter import get_progress_reporter, create_callback
 from aiida.tools.importexport.common.exceptions import (ArchiveMigrationError, CorruptArchive, DanglingLinkError)
 from aiida.tools.importexport.common.config import ExportFileFormat
-from aiida.tools.importexport.archive.common import (
-    read_file_in_tar, read_file_in_zip, safe_extract_tar, safe_extract_zip, compress_folder_tar, compress_folder_zip,
-    CacheFolder
-)
+from aiida.tools.importexport.archive.common import CacheFolder
 from aiida.tools.importexport.archive.migrations import MIGRATE_FUNCTIONS
 
 __all__ = (
@@ -227,39 +227,54 @@ def _extract_archive(self, filepath: Path, callback: Callable[[str, Any], None])
     @staticmethod
     def _compress_archive_zip(in_path: Path, out_path: Path, compression: int):
         """Create a new zip compressed zip from a folder."""
-        with get_progress_reporter()(total=1) as progress:
+        with get_progress_reporter()(total=1, desc='Compressing to zip') as progress:
             _callback = create_callback(progress)
-            compress_folder_zip(in_path, out_path, compression=compression, callback=_callback)
+            with ZipPath(out_path, mode='w', compression=compression, allow_zip64=True) as path:
+                path.puttree(in_path, check_exists=False, callback=_callback, cb_descript='Compressing to zip')
 
     @staticmethod
     def _compress_archive_tar(in_path: Path, out_path: Path):
         """Create a new zip compressed tar from a folder."""
-        with get_progress_reporter()(total=1) as progress:
+        with get_progress_reporter()(total=1, desc='Compressing to tar') as progress:
             _callback = create_callback(progress)
-            compress_folder_tar(in_path, out_path, callback=_callback)
+            with TarPath(out_path, mode='w:gz', dereference=True) as path:
+                path.puttree(in_path, check_exists=False, callback=_callback, cb_descript='Compressing to tar')
 
 
 class ArchiveMigratorJsonZip(ArchiveMigratorJsonBase):
     """A migrator for a JSON zip compressed format."""
 
     def _retrieve_version(self) -> str:
-        metadata = json.loads(read_file_in_zip(self.filepath, 'metadata.json'))
+        try:
+            metadata = json.loads(read_file_in_zip(self.filepath, 'metadata.json'))
+        except (IOError, FileNotFoundError) as error:
+            raise CorruptArchive(str(error))
         if 'export_version' not in metadata:
             raise CorruptArchive("metadata.json doest not contain an 'export_version' key")
         return metadata['export_version']
 
     def _extract_archive(self, filepath: Path, callback: Callable[[str, Any], None]):
-        safe_extract_zip(self.filepath, filepath, callback=callback)
+        try:
+            ZipPath(self.filepath, mode='r', allow_zip64=True).extract_tree(filepath, callback=callback)
+        except zipfile.BadZipfile as error:
+            raise CorruptArchive(f'The input file cannot be read: {error}')
 
 
 class ArchiveMigratorJsonTar(ArchiveMigratorJsonBase):
     """A migrator for a JSON tar compressed format."""
 
     def _retrieve_version(self) -> str:
-        metadata = json.loads(read_file_in_tar(self.filepath, 'metadata.json'))
+        try:
+            metadata = json.loads(read_file_in_tar(self.filepath, 'metadata.json'))
+        except (IOError, FileNotFoundError) as error:
+            raise CorruptArchive(str(error))
         if 'export_version' not in metadata:
             raise CorruptArchive("metadata.json doest not contain an 'export_version' key")
         return metadata['export_version']
 
     def _extract_archive(self, filepath: Path, callback: Callable[[str, Any], None]):
-        safe_extract_tar(self.filepath, filepath, callback=callback)
+        try:
+            TarPath(self.filepath, mode='r:*', pax_format=tarfile.PAX_FORMAT
+                    ).extract_tree(filepath, allow_dev=False, allow_symlink=False, callback=callback)
+        except tarfile.ReadError as error:
+            raise CorruptArchive(f'The input file cannot be read: {error}')
diff --git a/aiida/tools/importexport/archive/readers.py b/aiida/tools/importexport/archive/readers.py
index 3fe7c46380..65da299ab2 100644
--- a/aiida/tools/importexport/archive/readers.py
+++ b/aiida/tools/importexport/archive/readers.py
@@ -12,19 +12,20 @@ import json
 import os
 from pathlib import Path
+import tarfile
 from types import TracebackType
 from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type
+import zipfile
 
 from distutils.version import StrictVersion
+from archive_path import TarPath, ZipPath, read_file_in_tar, read_file_in_zip
 
 from aiida.common.log import AIIDA_LOGGER
 from aiida.common.exceptions import InvalidOperation
 from aiida.common.folders import Folder, SandboxFolder
 from aiida.tools.importexport.common.config import EXPORT_VERSION, ExportFileFormat, NODES_EXPORT_SUBFOLDER
 from aiida.tools.importexport.common.exceptions import (CorruptArchive, IncompatibleArchiveVersionError)
-from aiida.tools.importexport.archive.common import (
-    ArchiveMetadata, null_callback, read_file_in_tar, read_file_in_zip, safe_extract_zip, safe_extract_tar
-)
+from aiida.tools.importexport.archive.common import (ArchiveMetadata, null_callback)
 from aiida.tools.importexport.common.config import NODE_ENTITY_NAME, GROUP_ENTITY_NAME
 from aiida.tools.importexport.common.utils import export_shard_uuid
@@ -367,7 +368,7 @@ def iter_node_repos(
         assert self._sandbox is not None  # required by mypy
 
         # unarchive the common folder if it does not exist
-        common_prefix = os.path.commonprefix(path_prefixes)
+        common_prefix = os.path.commonpath(path_prefixes)
         if not self._sandbox.get_subfolder(common_prefix).exists():
             self._extract(path_prefix=common_prefix, callback=callback)
 
@@ -391,24 +392,31 @@ def file_format_verbose(self) -> str:
 
     def _get_metadata(self):
         if self._metadata is None:
-            self._metadata = json.loads(read_file_in_zip(self.filename, self.FILENAME_METADATA))
+            try:
+                self._metadata = json.loads(read_file_in_zip(self.filename, self.FILENAME_METADATA))
+            except (IOError, FileNotFoundError) as error:
+                raise CorruptArchive(str(error))
         return self._metadata
 
     def _get_data(self):
         if self._data is None:
-            self._data = json.loads(read_file_in_zip(self.filename, self.FILENAME_DATA))
+            try:
+                self._data = json.loads(read_file_in_zip(self.filename, self.FILENAME_DATA))
+            except (IOError, FileNotFoundError) as error:
+                raise CorruptArchive(str(error))
         return self._data
 
     def _extract(self, *, path_prefix: str, callback: Callable[[str, Any], None] = null_callback):
         self.assert_within_context()
         assert self._sandbox is not None  # required by mypy
-        safe_extract_zip(
-            self.filename,
-            self._sandbox.abspath,
-            only_prefix=[path_prefix],
-            callback=callback,
-            callback_description='Extracting repository files'
-        )
+        try:
+            ZipPath(self.filename, mode='r', allow_zip64=True).joinpath(path_prefix).extract_tree(
+                self._sandbox.abspath, callback=callback, cb_descript='Extracting repository files'
+            )
+        except zipfile.BadZipfile as error:
+            raise CorruptArchive(f'The input file cannot be read: {error}')
+        except NotADirectoryError as error:
+            raise CorruptArchive(f'Unable to find required folder in archive: {error}')
 
 
 class ReaderJsonTar(ReaderJsonBase):
@@ -420,24 +428,35 @@ def file_format_verbose(self) -> str:
 
     def _get_metadata(self):
         if self._metadata is None:
-            self._metadata = json.loads(read_file_in_tar(self.filename, self.FILENAME_METADATA))
+            try:
+                self._metadata = json.loads(read_file_in_tar(self.filename, self.FILENAME_METADATA))
+            except (IOError, FileNotFoundError) as error:
+                raise CorruptArchive(str(error))
         return self._metadata
 
     def _get_data(self):
         if self._data is None:
-            self._data = json.loads(read_file_in_tar(self.filename, self.FILENAME_DATA))
+            try:
+                self._data = json.loads(read_file_in_tar(self.filename, self.FILENAME_DATA))
+            except (IOError, FileNotFoundError) as error:
+                raise CorruptArchive(str(error))
         return self._data
 
     def _extract(self, *, path_prefix: str, callback: Callable[[str, Any], None] = null_callback):
         self.assert_within_context()
         assert self._sandbox is not None  # required by mypy
-        safe_extract_tar(
-            self.filename,
-            self._sandbox.abspath,
-            only_prefix=[path_prefix],
-            callback=callback,
-            callback_description='Extracting repository files'
-        )
+        try:
+            TarPath(self.filename, mode='r:*').joinpath(path_prefix).extract_tree(
+                self._sandbox.abspath,
+                allow_dev=False,
+                allow_symlink=False,
+                callback=callback,
+                cb_descript='Extracting repository files'
+            )
+        except tarfile.ReadError as error:
+            raise CorruptArchive(f'The input file cannot be read: {error}')
+        except NotADirectoryError as error:
+            raise CorruptArchive(f'Unable to find required folder in archive: {error}')
 
 
 class ReaderJsonFolder(ReaderJsonBase):
diff --git a/aiida/tools/importexport/archive/writers.py b/aiida/tools/importexport/archive/writers.py
index 814795cc0a..a4d96ba7e2 100644
--- a/aiida/tools/importexport/archive/writers.py
+++ b/aiida/tools/importexport/archive/writers.py
@@ -9,20 +9,29 @@
 ###########################################################################
 """Archive writer classes."""
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
-import json
+from copy import deepcopy
+import os
+from pathlib import Path
+import shelve
+import shutil
 import time
-import tarfile
-from typing import Any, Callable, cast, Dict, Iterable, List, Set, Tuple, Type, Union
+import tempfile
+from types import TracebackType
+from typing import Any, cast, Dict, List, Optional, Type, Union
+import zipfile
 
-from aiida.common.folders import Folder, SandboxFolder
-from aiida.common.progress_reporter import get_progress_reporter
+from archive_path import TarPath, ZipPath
+
+from aiida.common import json
+from aiida.common.exceptions import InvalidOperation
+from aiida.common.folders import Folder
 from aiida.tools.importexport.archive.common import ArchiveMetadata
-from aiida.tools.importexport.common.config import EXPORT_VERSION, NODES_EXPORT_SUBFOLDER, ExportFileFormat
+from aiida.tools.importexport.common.config import (
+    EXPORT_VERSION, NODE_ENTITY_NAME, NODES_EXPORT_SUBFOLDER, ExportFileFormat
+)
 from aiida.tools.importexport.common.utils import export_shard_uuid
-from aiida.tools.importexport.common.zip_folder import ZipFolder
 
-__all__ = ('ArchiveData', 'ArchiveWriterAbstract', 'get_writer', 'WriterJsonFolder', 'WriterJsonTar', 'WriterJsonZip')
+__all__ = ('ArchiveWriterAbstract', 'get_writer', 'WriterJsonZip', 'WriterJsonTar', 'WriterJsonFolder')
 
 
 def get_writer(file_format: str) -> Type['ArchiveWriterAbstract']:
@@ -31,229 +40,380 @@ def get_writer(file_format: str) -> Type['ArchiveWriterAbstract']:
         ExportFileFormat.ZIP: WriterJsonZip,
         ExportFileFormat.TAR_GZIPPED: WriterJsonTar,
         'folder': WriterJsonFolder,
+        'null': WriterNull,
     }
 
     if file_format not in writers:
         raise ValueError(
-            f'Can only write in the formats: {tuple(writers.keys())}, please specify one for "file_format".'
+            f'Can only write in the formats: {tuple(writers.keys())}, not {file_format}, '
+            'please specify one for "file_format".'
         )
 
     return cast(Type[ArchiveWriterAbstract], writers[file_format])
 
 
-@dataclass
-class ArchiveData:
-    """Class for storing data, to export to an AiiDA archive."""
-    metadata: ArchiveMetadata
-    node_uuids: Set[str]
-    # UUID of the group -> UUIDs of the entities it contains
-    group_uuids: Dict[str, Set[str]]
-    # list of {'input': <UUID>, 'output': <UUID>, 'label':