diff --git a/src/tablib/core.py b/src/tablib/core.py index a153bf17..ec4bc646 100644 --- a/src/tablib/core.py +++ b/src/tablib/core.py @@ -22,17 +22,17 @@ from tablib.formats import registry from tablib.utils import normalize_input -__title__ = 'tablib' -__author__ = 'Kenneth Reitz' -__license__ = 'MIT' -__copyright__ = 'Copyright 2017 Kenneth Reitz. 2019 Jazzband.' -__docformat__ = 'restructuredtext' +__title__ = "tablib" +__author__ = "Kenneth Reitz" +__license__ = "MIT" +__copyright__ = "Copyright 2017 Kenneth Reitz. 2019 Jazzband." +__docformat__ = "restructuredtext" class Row: """Internal Row object. Mainly used for filtering.""" - __slots__ = ['_row', 'tags'] + __slots__ = ["_row", "tags"] def __init__(self, row=(), tags=()): self._row = list(row) @@ -75,7 +75,7 @@ def insert(self, index, value): self._row.insert(index, value) def __contains__(self, item): - return (item in self._row) + return item in self._row @property def tuple(self): @@ -93,7 +93,7 @@ def has_tag(self, tag): if tag is None: return False elif isinstance(tag, str): - return (tag in self.tags) + return tag in self.tags else: return bool(len(set(tag) & set(self.tags))) @@ -155,9 +155,9 @@ def __init__(self, *args, **kwargs): # (column, callback) tuples self._formatters = [] - self.headers = kwargs.get('headers') + self.headers = kwargs.get("headers") - self.title = kwargs.get('title') + self.title = kwargs.get("title") def __len__(self): return self.height @@ -182,14 +182,11 @@ def __setitem__(self, key, value): def __delitem__(self, key): if isinstance(key, str): - if key in self.headers: - pos = self.headers.index(key) del self.headers[pos] for i, row in enumerate(self._data): - del row[pos] self._data[i] = row else: @@ -199,9 +196,9 @@ def __delitem__(self, key): def __repr__(self): try: - return '<%s dataset>' % (self.title.lower()) + return "<%s dataset>" % (self.title.lower()) except AttributeError: - return '' + return "" def __str__(self): result = [] @@ -218,11 +215,11 @@ def __str__(self): # delimiter between header and data if self.__headers: - result.insert(1, ['-' * length for length in field_lens]) + result.insert(1, ["-" * length for length in field_lens]) - format_string = '|'.join('{%s:%s}' % item for item in enumerate(field_lens)) + format_string = "|".join("{%s:%s}" % item for item in enumerate(field_lens)) - return '\n'.join(format_string.format(*row) for row in result) + return "\n".join(format_string.format(*row) for row in result) # --------- # Internals @@ -280,7 +277,9 @@ def _package(self, dicts=True, ordered=True): if self.headers: if dicts: - data = [dict_pack(list(zip(self.headers, data_row))) for data_row in _data] + data = [ + dict_pack(list(zip(self.headers, data_row))) for data_row in _data + ] else: data = [list(self.headers)] + list(_data) else: @@ -373,8 +372,7 @@ def _clean_col(self, col): else: header = [] - if len(col) == 1 and hasattr(col[0], '__call__'): - + if len(col) == 1 and hasattr(col[0], "__call__"): col = list(map(col[0], self._data)) col = tuple(header + col) @@ -383,14 +381,14 @@ def _clean_col(self, col): @property def height(self): """The number of rows currently in the :class:`Dataset`. - Cannot be directly modified. + Cannot be directly modified. """ return len(self._data) @property def width(self): """The number of columns currently in the :class:`Dataset`. - Cannot be directly modified. + Cannot be directly modified. """ try: @@ -414,11 +412,11 @@ def load(self, in_stream, format=None, **kwargs): format = detect_format(stream) fmt = registry.get_format(format) - if not hasattr(fmt, 'import_set'): - raise UnsupportedFormat(f'Format {format} cannot be imported.') + if not hasattr(fmt, "import_set"): + raise UnsupportedFormat(f"Format {format} cannot be imported.") if not import_set: - raise UnsupportedFormat(f'Format {format} cannot be imported.') + raise UnsupportedFormat(f"Format {format} cannot be imported.") fmt.import_set(self, stream, **kwargs) return self @@ -430,8 +428,8 @@ def export(self, format, **kwargs): :param \\*\\*kwargs: (optional) custom configuration to the format `export_set`. """ fmt = registry.get_format(format) - if not hasattr(fmt, 'export_set'): - raise UnsupportedFormat(f'Format {format} cannot be exported.') + if not hasattr(fmt, "export_set"): + raise UnsupportedFormat(f"Format {format} cannot be exported.") return fmt.export_set(self, **kwargs) @@ -446,7 +444,7 @@ def insert(self, index, row, tags=()): The default behaviour is to insert the given row to the :class:`Dataset` object at the given index. - """ + """ self._validate(row) self._data.insert(index, Row(row, tags=tags)) @@ -501,10 +499,13 @@ def pop(self): return self.rpop() - def get(self, index): + def get(self, index: int): """Returns the row from the :class:`Dataset` at the given index.""" - return self[index] + if isinstance(index, int): + return self[index] + + raise TypeError("Index must be an integer.") # ------- # Columns @@ -543,7 +544,7 @@ def insert_col(self, index, col=None, header=None): col = [] # Callable Columns... - if hasattr(col, '__call__'): + if hasattr(col, "__call__"): col = list(map(col, self._data)) col = self._clean_col(col) @@ -561,9 +562,7 @@ def insert_col(self, index, col=None, header=None): self.headers.insert(index, header) if self.height and self.width: - for i, row in enumerate(self._data): - row.insert(index, col[i]) self._data[i] = row else: @@ -583,13 +582,13 @@ def lpush_col(self, col, header=None): self.insert_col(0, col, header=header) - def insert_separator(self, index, text='-'): + def insert_separator(self, index, text="-"): """Adds a separator to :class:`Dataset` at given index.""" sep = (index, text) self._separators.append(sep) - def append_separator(self, text='-'): + def append_separator(self, text="-"): """Adds a :ref:`separator ` to the :class:`Dataset`.""" # change offsets if headers are or aren't defined @@ -658,7 +657,6 @@ def sort(self, col, reverse=False): """ if isinstance(col, str): - if not self.headers: raise HeadersNeeded @@ -701,7 +699,6 @@ def transpose(self): _dset.headers = new_headers for index, column in enumerate(self.headers): - if column == self.headers[0]: # It's in the headers, so skip it continue @@ -773,7 +770,9 @@ def remove_duplicates(self): while maintaining the original order.""" seen = set() self._data[:] = [ - row for row in self._data if not (tuple(row) in seen or seen.add(tuple(row))) + row + for row in self._data + if not (tuple(row) in seen or seen.add(tuple(row))) ] def wipe(self): @@ -822,17 +821,16 @@ def subset(self, rows=None, cols=None): class Databook: - """A book of :class:`Dataset` objects. - """ + """A book of :class:`Dataset` objects.""" def __init__(self, sets=None): self._datasets = sets or [] def __repr__(self): try: - return '<%s databook>' % (self.title.lower()) + return "<%s databook>" % (self.title.lower()) except AttributeError: - return '' + return "" def wipe(self): """Removes all :class:`Dataset` objects from the :class:`Databook`.""" @@ -858,10 +856,9 @@ def _package(self, ordered=True): dict_pack = dict for dset in self._datasets: - collector.append(dict_pack( - title=dset.title, - data=dset._package(ordered=ordered) - )) + collector.append( + dict_pack(title=dset.title, data=dset._package(ordered=ordered)) + ) return collector @property @@ -882,8 +879,8 @@ def load(self, in_stream, format, **kwargs): format = detect_format(stream) fmt = registry.get_format(format) - if not hasattr(fmt, 'import_book'): - raise UnsupportedFormat(f'Format {format} cannot be loaded.') + if not hasattr(fmt, "import_book"): + raise UnsupportedFormat(f"Format {format} cannot be loaded.") fmt.import_book(self, stream, **kwargs) return self @@ -895,8 +892,8 @@ def export(self, format, **kwargs): :param \\*\\*kwargs: (optional) custom configuration to the format `export_book`. """ fmt = registry.get_format(format) - if not hasattr(fmt, 'export_book'): - raise UnsupportedFormat(f'Format {format} cannot be exported.') + if not hasattr(fmt, "export_book"): + raise UnsupportedFormat(f"Format {format} cannot be exported.") return fmt.export_book(self, **kwargs) @@ -913,7 +910,7 @@ def detect_format(stream): except AttributeError: pass finally: - if hasattr(stream, 'seek'): + if hasattr(stream, "seek"): stream.seek(0) return fmt_title diff --git a/tests/test_tablib.py b/tests/test_tablib.py index eeb4dd3b..13dea545 100755 --- a/tests/test_tablib.py +++ b/tests/test_tablib.py @@ -209,6 +209,9 @@ def test_get(self): with self.assertRaises(IndexError): self.founders.get(3) + with self.assertRaises(TypeError): + self.founders.get('first_name') + def test_get_col(self): """Verify getting columns by index"""