Skip to content

Commit

Permalink
add type validation
Browse files Browse the repository at this point in the history
  • Loading branch information
yoonthegoon committed Aug 2, 2023
1 parent c45c50c commit b7fa506
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 52 deletions.
101 changes: 49 additions & 52 deletions src/tablib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@
from tablib.formats import registry
from tablib.utils import normalize_input

__title__ = 'tablib'
__author__ = 'Kenneth Reitz'
__license__ = 'MIT'
__copyright__ = 'Copyright 2017 Kenneth Reitz. 2019 Jazzband.'
__docformat__ = 'restructuredtext'
__title__ = "tablib"
__author__ = "Kenneth Reitz"
__license__ = "MIT"
__copyright__ = "Copyright 2017 Kenneth Reitz. 2019 Jazzband."
__docformat__ = "restructuredtext"


class Row:
"""Internal Row object. Mainly used for filtering."""

__slots__ = ['_row', 'tags']
__slots__ = ["_row", "tags"]

def __init__(self, row=(), tags=()):
self._row = list(row)
Expand Down Expand Up @@ -75,7 +75,7 @@ def insert(self, index, value):
self._row.insert(index, value)

def __contains__(self, item):
return (item in self._row)
return item in self._row

@property
def tuple(self):
Expand All @@ -93,7 +93,7 @@ def has_tag(self, tag):
if tag is None:
return False
elif isinstance(tag, str):
return (tag in self.tags)
return tag in self.tags
else:
return bool(len(set(tag) & set(self.tags)))

Expand Down Expand Up @@ -155,9 +155,9 @@ def __init__(self, *args, **kwargs):
# (column, callback) tuples
self._formatters = []

self.headers = kwargs.get('headers')
self.headers = kwargs.get("headers")

self.title = kwargs.get('title')
self.title = kwargs.get("title")

def __len__(self):
return self.height
Expand All @@ -182,14 +182,11 @@ def __setitem__(self, key, value):

def __delitem__(self, key):
if isinstance(key, str):

if key in self.headers:

pos = self.headers.index(key)
del self.headers[pos]

for i, row in enumerate(self._data):

del row[pos]
self._data[i] = row
else:
Expand All @@ -199,9 +196,9 @@ def __delitem__(self, key):

def __repr__(self):
try:
return '<%s dataset>' % (self.title.lower())
return "<%s dataset>" % (self.title.lower())
except AttributeError:
return '<dataset object>'
return "<dataset object>"

def __str__(self):
result = []
Expand All @@ -218,11 +215,11 @@ def __str__(self):

# delimiter between header and data
if self.__headers:
result.insert(1, ['-' * length for length in field_lens])
result.insert(1, ["-" * length for length in field_lens])

format_string = '|'.join('{%s:%s}' % item for item in enumerate(field_lens))
format_string = "|".join("{%s:%s}" % item for item in enumerate(field_lens))

return '\n'.join(format_string.format(*row) for row in result)
return "\n".join(format_string.format(*row) for row in result)

# ---------
# Internals
Expand Down Expand Up @@ -280,7 +277,9 @@ def _package(self, dicts=True, ordered=True):

if self.headers:
if dicts:
data = [dict_pack(list(zip(self.headers, data_row))) for data_row in _data]
data = [
dict_pack(list(zip(self.headers, data_row))) for data_row in _data
]
else:
data = [list(self.headers)] + list(_data)
else:
Expand Down Expand Up @@ -373,8 +372,7 @@ def _clean_col(self, col):
else:
header = []

if len(col) == 1 and hasattr(col[0], '__call__'):

if len(col) == 1 and hasattr(col[0], "__call__"):
col = list(map(col[0], self._data))
col = tuple(header + col)

Expand All @@ -383,14 +381,14 @@ def _clean_col(self, col):
@property
def height(self):
"""The number of rows currently in the :class:`Dataset`.
Cannot be directly modified.
Cannot be directly modified.
"""
return len(self._data)

@property
def width(self):
"""The number of columns currently in the :class:`Dataset`.
Cannot be directly modified.
Cannot be directly modified.
"""

try:
Expand All @@ -414,11 +412,11 @@ def load(self, in_stream, format=None, **kwargs):
format = detect_format(stream)

fmt = registry.get_format(format)
if not hasattr(fmt, 'import_set'):
raise UnsupportedFormat(f'Format {format} cannot be imported.')
if not hasattr(fmt, "import_set"):
raise UnsupportedFormat(f"Format {format} cannot be imported.")

if not import_set:
raise UnsupportedFormat(f'Format {format} cannot be imported.')
raise UnsupportedFormat(f"Format {format} cannot be imported.")

fmt.import_set(self, stream, **kwargs)
return self
Expand All @@ -430,8 +428,8 @@ def export(self, format, **kwargs):
:param \\*\\*kwargs: (optional) custom configuration to the format `export_set`.
"""
fmt = registry.get_format(format)
if not hasattr(fmt, 'export_set'):
raise UnsupportedFormat(f'Format {format} cannot be exported.')
if not hasattr(fmt, "export_set"):
raise UnsupportedFormat(f"Format {format} cannot be exported.")

return fmt.export_set(self, **kwargs)

Expand All @@ -446,7 +444,7 @@ def insert(self, index, row, tags=()):
The default behaviour is to insert the given row to the :class:`Dataset`
object at the given index.
"""
"""

self._validate(row)
self._data.insert(index, Row(row, tags=tags))
Expand Down Expand Up @@ -501,10 +499,13 @@ def pop(self):

return self.rpop()

def get(self, index):
def get(self, index: int):
"""Returns the row from the :class:`Dataset` at the given index."""

return self[index]
if isinstance(index, int):
return self[index]

raise TypeError("Index must be an integer.")

# -------
# Columns
Expand Down Expand Up @@ -543,7 +544,7 @@ def insert_col(self, index, col=None, header=None):
col = []

# Callable Columns...
if hasattr(col, '__call__'):
if hasattr(col, "__call__"):
col = list(map(col, self._data))

col = self._clean_col(col)
Expand All @@ -561,9 +562,7 @@ def insert_col(self, index, col=None, header=None):
self.headers.insert(index, header)

if self.height and self.width:

for i, row in enumerate(self._data):

row.insert(index, col[i])
self._data[i] = row
else:
Expand All @@ -583,13 +582,13 @@ def lpush_col(self, col, header=None):

self.insert_col(0, col, header=header)

def insert_separator(self, index, text='-'):
def insert_separator(self, index, text="-"):
"""Adds a separator to :class:`Dataset` at given index."""

sep = (index, text)
self._separators.append(sep)

def append_separator(self, text='-'):
def append_separator(self, text="-"):
"""Adds a :ref:`separator <separators>` to the :class:`Dataset`."""

# change offsets if headers are or aren't defined
Expand Down Expand Up @@ -658,7 +657,6 @@ def sort(self, col, reverse=False):
"""

if isinstance(col, str):

if not self.headers:
raise HeadersNeeded

Expand Down Expand Up @@ -701,7 +699,6 @@ def transpose(self):

_dset.headers = new_headers
for index, column in enumerate(self.headers):

if column == self.headers[0]:
# It's in the headers, so skip it
continue
Expand Down Expand Up @@ -773,7 +770,9 @@ def remove_duplicates(self):
while maintaining the original order."""
seen = set()
self._data[:] = [
row for row in self._data if not (tuple(row) in seen or seen.add(tuple(row)))
row
for row in self._data
if not (tuple(row) in seen or seen.add(tuple(row)))
]

def wipe(self):
Expand Down Expand Up @@ -822,17 +821,16 @@ def subset(self, rows=None, cols=None):


class Databook:
"""A book of :class:`Dataset` objects.
"""
"""A book of :class:`Dataset` objects."""

def __init__(self, sets=None):
self._datasets = sets or []

def __repr__(self):
try:
return '<%s databook>' % (self.title.lower())
return "<%s databook>" % (self.title.lower())
except AttributeError:
return '<databook object>'
return "<databook object>"

def wipe(self):
"""Removes all :class:`Dataset` objects from the :class:`Databook`."""
Expand All @@ -858,10 +856,9 @@ def _package(self, ordered=True):
dict_pack = dict

for dset in self._datasets:
collector.append(dict_pack(
title=dset.title,
data=dset._package(ordered=ordered)
))
collector.append(
dict_pack(title=dset.title, data=dset._package(ordered=ordered))
)
return collector

@property
Expand All @@ -882,8 +879,8 @@ def load(self, in_stream, format, **kwargs):
format = detect_format(stream)

fmt = registry.get_format(format)
if not hasattr(fmt, 'import_book'):
raise UnsupportedFormat(f'Format {format} cannot be loaded.')
if not hasattr(fmt, "import_book"):
raise UnsupportedFormat(f"Format {format} cannot be loaded.")

fmt.import_book(self, stream, **kwargs)
return self
Expand All @@ -895,8 +892,8 @@ def export(self, format, **kwargs):
:param \\*\\*kwargs: (optional) custom configuration to the format `export_book`.
"""
fmt = registry.get_format(format)
if not hasattr(fmt, 'export_book'):
raise UnsupportedFormat(f'Format {format} cannot be exported.')
if not hasattr(fmt, "export_book"):
raise UnsupportedFormat(f"Format {format} cannot be exported.")

return fmt.export_book(self, **kwargs)

Expand All @@ -913,7 +910,7 @@ def detect_format(stream):
except AttributeError:
pass
finally:
if hasattr(stream, 'seek'):
if hasattr(stream, "seek"):
stream.seek(0)
return fmt_title

Expand Down
3 changes: 3 additions & 0 deletions tests/test_tablib.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ def test_get(self):
with self.assertRaises(IndexError):
self.founders.get(3)

with self.assertRaises(TypeError):
self.founders.get('first_name')

def test_get_col(self):
"""Verify getting columns by index"""

Expand Down

0 comments on commit b7fa506

Please sign in to comment.