Skip to content

Commit

Permalink
Improvement: accept int value for Resources (#2196)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Lo <wenchih@apache.org>
Co-authored-by: Kevin Su <pingsutw@gmail.com>
Signed-off-by: Jan Fiedler <jan@union.ai>
  • Loading branch information
2 people authored and fiedlerNr9 committed Jul 25, 2024
1 parent 1fa55e4 commit b26e298
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 30 deletions.
41 changes: 25 additions & 16 deletions flytekit/core/resources.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Union

from mashumaro.mixins.json import DataClassJSONMixin

Expand All @@ -15,6 +15,7 @@ class Resources(DataClassJSONMixin):
Resources(cpu="1", mem="2048") # This is 1 CPU and 2 KB of memory
Resources(cpu="100m", mem="2Gi") # This is 1/10th of a CPU and 2 gigabytes of memory
Resources(cpu=0.5, mem=1024) # This is 500m CPU and 1 KB of memory
# For Kubernetes-based tasks, pods use ephemeral local storage for scratch space, caching, and for logs.
# This allocates 1Gi of such local storage.
Expand All @@ -28,22 +29,28 @@ class Resources(DataClassJSONMixin):
Also refer to the `K8s conventions. <https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#resource-units-in-kubernetes>`__
"""

cpu: Optional[str] = None
mem: Optional[str] = None
gpu: Optional[str] = None
ephemeral_storage: Optional[str] = None
cpu: Optional[Union[str, int, float]] = None
mem: Optional[Union[str, int]] = None
gpu: Optional[Union[str, int]] = None
ephemeral_storage: Optional[Union[str, int]] = None

def __post_init__(self):
def _check_none_or_str(value):
def _check_cpu(value):
if value is None:
return
if not isinstance(value, str):
raise AssertionError(f"{value} should be a string")
if not isinstance(value, (str, int, float)):
raise AssertionError(f"{value} should be of type str or int or float")

_check_none_or_str(self.cpu)
_check_none_or_str(self.mem)
_check_none_or_str(self.gpu)
_check_none_or_str(self.ephemeral_storage)
def _check_others(value):
if value is None:
return
if not isinstance(value, (str, int)):
raise AssertionError(f"{value} should be of type str or int")

_check_cpu(self.cpu)
_check_others(self.mem)
_check_others(self.gpu)
_check_others(self.ephemeral_storage)


@dataclass
Expand All @@ -59,13 +66,15 @@ class ResourceSpec(DataClassJSONMixin):
def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: # type: ignore
resource_entries = []
if resources.cpu is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=resources.cpu))
resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=str(resources.cpu)))
if resources.mem is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=resources.mem))
resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=str(resources.mem)))
if resources.gpu is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=resources.gpu))
resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu)))
if resources.ephemeral_storage is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=resources.ephemeral_storage))
resource_entries.append(
_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage))
)
return resource_entries


Expand Down
18 changes: 9 additions & 9 deletions flytekit/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import wraps
from hashlib import sha224 as _sha224
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast

from flyteidl.core import tasks_pb2 as _core_task

Expand Down Expand Up @@ -62,14 +62,14 @@ def _get_container_definition(
command: List[str],
args: Optional[List[str]] = None,
data_loading_config: Optional["task_models.DataLoadingConfig"] = None,
ephemeral_storage_request: Optional[str] = None,
cpu_request: Optional[str] = None,
gpu_request: Optional[str] = None,
memory_request: Optional[str] = None,
ephemeral_storage_limit: Optional[str] = None,
cpu_limit: Optional[str] = None,
gpu_limit: Optional[str] = None,
memory_limit: Optional[str] = None,
ephemeral_storage_request: Optional[Union[str, int]] = None,
cpu_request: Optional[Union[str, int, float]] = None,
gpu_request: Optional[Union[str, int]] = None,
memory_request: Optional[Union[str, int]] = None,
ephemeral_storage_limit: Optional[Union[str, int]] = None,
cpu_limit: Optional[Union[str, int, float]] = None,
gpu_limit: Optional[Union[str, int]] = None,
memory_limit: Optional[Union[str, int]] = None,
environment: Optional[Dict[str, str]] = None,
) -> "task_models.Container":
ephemeral_storage_limit = ephemeral_storage_limit
Expand Down
8 changes: 4 additions & 4 deletions tests/flytekit/unit/core/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _

def test_incorrect_type_resources():
with pytest.raises(AssertionError):
Resources(cpu=1) # type: ignore
Resources(cpu=bytes(1)) # type: ignore
with pytest.raises(AssertionError):
Resources(mem=1) # type: ignore
Resources(mem=0.1) # type: ignore
with pytest.raises(AssertionError):
Resources(gpu=1) # type: ignore
Resources(gpu=0.1) # type: ignore
with pytest.raises(AssertionError):
Resources(ephemeral_storage=1) # type: ignore
Resources(ephemeral_storage=0.1) # type: ignore


def test_resources_serialization():
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1983,7 +1983,7 @@ def t1(a: int) -> int:

@workflow
def my_wf(a: int) -> int:
return t1(a=a).with_overrides(requests=Resources(cpu=1)) # type: ignore
return t1(a=a).with_overrides(requests=Resources(cpu=1, mem=1.1)) # type: ignore

with pytest.raises(AssertionError):
my_wf(a=1)
Expand Down

0 comments on commit b26e298

Please sign in to comment.