From b26e298f7822bf96204e1438042a43f185c580c5 Mon Sep 17 00:00:00 2001 From: WenChih Lo Date: Tue, 30 Apr 2024 21:03:45 +0800 Subject: [PATCH] Improvement: accept int value for Resources (#2196) Signed-off-by: Ryan Lo Co-authored-by: Kevin Su Signed-off-by: Jan Fiedler --- flytekit/core/resources.py | 41 +++++++++++++-------- flytekit/core/utils.py | 18 ++++----- tests/flytekit/unit/core/test_resources.py | 8 ++-- tests/flytekit/unit/core/test_type_hints.py | 2 +- 4 files changed, 39 insertions(+), 30 deletions(-) diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 50cd68ecd0..8a99dbf2ea 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Union from mashumaro.mixins.json import DataClassJSONMixin @@ -15,6 +15,7 @@ class Resources(DataClassJSONMixin): Resources(cpu="1", mem="2048") # This is 1 CPU and 2 KB of memory Resources(cpu="100m", mem="2Gi") # This is 1/10th of a CPU and 2 gigabytes of memory + Resources(cpu=0.5, mem=1024) # This is 500m CPU and 1 KB of memory # For Kubernetes-based tasks, pods use ephemeral local storage for scratch space, caching, and for logs. # This allocates 1Gi of such local storage. @@ -28,22 +29,28 @@ class Resources(DataClassJSONMixin): Also refer to the `K8s conventions. `__ """ - cpu: Optional[str] = None - mem: Optional[str] = None - gpu: Optional[str] = None - ephemeral_storage: Optional[str] = None + cpu: Optional[Union[str, int, float]] = None + mem: Optional[Union[str, int]] = None + gpu: Optional[Union[str, int]] = None + ephemeral_storage: Optional[Union[str, int]] = None def __post_init__(self): - def _check_none_or_str(value): + def _check_cpu(value): if value is None: return - if not isinstance(value, str): - raise AssertionError(f"{value} should be a string") + if not isinstance(value, (str, int, float)): + raise AssertionError(f"{value} should be of type str or int or float") - _check_none_or_str(self.cpu) - _check_none_or_str(self.mem) - _check_none_or_str(self.gpu) - _check_none_or_str(self.ephemeral_storage) + def _check_others(value): + if value is None: + return + if not isinstance(value, (str, int)): + raise AssertionError(f"{value} should be of type str or int") + + _check_cpu(self.cpu) + _check_others(self.mem) + _check_others(self.gpu) + _check_others(self.ephemeral_storage) @dataclass @@ -59,13 +66,15 @@ class ResourceSpec(DataClassJSONMixin): def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: # type: ignore resource_entries = [] if resources.cpu is not None: - resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=resources.cpu)) + resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=str(resources.cpu))) if resources.mem is not None: - resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=resources.mem)) + resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=str(resources.mem))) if resources.gpu is not None: - resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=resources.gpu)) + resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu))) if resources.ephemeral_storage is not None: - resource_entries.append(_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=resources.ephemeral_storage)) + resource_entries.append( + _ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage)) + ) return resource_entries diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index 0a50b66cc2..3106b3294e 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -7,7 +7,7 @@ from functools import wraps from hashlib import sha224 as _sha224 from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast from flyteidl.core import tasks_pb2 as _core_task @@ -62,14 +62,14 @@ def _get_container_definition( command: List[str], args: Optional[List[str]] = None, data_loading_config: Optional["task_models.DataLoadingConfig"] = None, - ephemeral_storage_request: Optional[str] = None, - cpu_request: Optional[str] = None, - gpu_request: Optional[str] = None, - memory_request: Optional[str] = None, - ephemeral_storage_limit: Optional[str] = None, - cpu_limit: Optional[str] = None, - gpu_limit: Optional[str] = None, - memory_limit: Optional[str] = None, + ephemeral_storage_request: Optional[Union[str, int]] = None, + cpu_request: Optional[Union[str, int, float]] = None, + gpu_request: Optional[Union[str, int]] = None, + memory_request: Optional[Union[str, int]] = None, + ephemeral_storage_limit: Optional[Union[str, int]] = None, + cpu_limit: Optional[Union[str, int, float]] = None, + gpu_limit: Optional[Union[str, int]] = None, + memory_limit: Optional[Union[str, int]] = None, environment: Optional[Dict[str, str]] = None, ) -> "task_models.Container": ephemeral_storage_limit = ephemeral_storage_limit diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py index a6bbe359e6..5dd9926039 100644 --- a/tests/flytekit/unit/core/test_resources.py +++ b/tests/flytekit/unit/core/test_resources.py @@ -68,13 +68,13 @@ def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _ def test_incorrect_type_resources(): with pytest.raises(AssertionError): - Resources(cpu=1) # type: ignore + Resources(cpu=bytes(1)) # type: ignore with pytest.raises(AssertionError): - Resources(mem=1) # type: ignore + Resources(mem=0.1) # type: ignore with pytest.raises(AssertionError): - Resources(gpu=1) # type: ignore + Resources(gpu=0.1) # type: ignore with pytest.raises(AssertionError): - Resources(ephemeral_storage=1) # type: ignore + Resources(ephemeral_storage=0.1) # type: ignore def test_resources_serialization(): diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 00410f74d6..3fa9579bd5 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -1983,7 +1983,7 @@ def t1(a: int) -> int: @workflow def my_wf(a: int) -> int: - return t1(a=a).with_overrides(requests=Resources(cpu=1)) # type: ignore + return t1(a=a).with_overrides(requests=Resources(cpu=1, mem=1.1)) # type: ignore with pytest.raises(AssertionError): my_wf(a=1)