diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 107e627252..7c5040516d 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -13,7 +13,7 @@ class Resources(object): Resources(cpu="1", mem="2048") # This is 1 CPU and 2 KB of memory Resources(cpu="100m", mem="2Gi") # This is 1/10th of a CPU and 2 gigabytes of memory - Resources(cpu=2, mem=1024) # This is 2 CPU and 1 KB of memory + Resources(cpu=0.5, mem=1024) # This is 500m CPU and 1 KB of memory # For Kubernetes-based tasks, pods use ephemeral local storage for scratch space, caching, and for logs. # This allocates 1Gi of such local storage. @@ -27,22 +27,28 @@ class Resources(object): Also refer to the `K8s conventions. `__ """ - cpu: Optional[Union[str, int]] = None + cpu: Optional[Union[str, int, float]] = None mem: Optional[Union[str, int]] = None gpu: Optional[Union[str, int]] = None ephemeral_storage: Optional[Union[str, int]] = None def __post_init__(self): - def _check_none_or_str_or_int(value): + def _check_cpu(value): + if value is None: + return + if not isinstance(value, (str, int, float)): + raise AssertionError(f"{value} should be of type str or int or float") + + def _check_others(value): if value is None: return if not isinstance(value, (str, int)): - raise AssertionError(f"{value} should be a string or an integer") + raise AssertionError(f"{value} should be of type str or int") - _check_none_or_str_or_int(self.cpu) - _check_none_or_str_or_int(self.mem) - _check_none_or_str_or_int(self.gpu) - _check_none_or_str_or_int(self.ephemeral_storage) + _check_cpu(self.cpu) + _check_others(self.mem) + _check_others(self.gpu) + _check_others(self.ephemeral_storage) @dataclass diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py index 465e2b19d4..0fcd31c756 100644 --- a/tests/flytekit/unit/core/test_resources.py +++ b/tests/flytekit/unit/core/test_resources.py @@ -68,7 +68,7 @@ def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _ def test_incorrect_type_resources(): with pytest.raises(AssertionError): - Resources(cpu=0.1) # type: ignore + Resources(cpu=bytes(1)) # type: ignore with pytest.raises(AssertionError): Resources(mem=0.1) # type: ignore with pytest.raises(AssertionError):