Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement: accept int value for Resources #2196

Merged
merged 13 commits into from
Apr 30, 2024
41 changes: 25 additions & 16 deletions flytekit/core/resources.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Union

from mashumaro.mixins.json import DataClassJSONMixin

Expand All @@ -15,6 +15,7 @@ class Resources(DataClassJSONMixin):

Resources(cpu="1", mem="2048") # This is 1 CPU and 2 KB of memory
Resources(cpu="100m", mem="2Gi") # This is 1/10th of a CPU and 2 gigabytes of memory
Resources(cpu=0.5, mem=1024) # This is 500m CPU and 1 KB of memory

# For Kubernetes-based tasks, pods use ephemeral local storage for scratch space, caching, and for logs.
# This allocates 1Gi of such local storage.
Expand All @@ -28,22 +29,28 @@ class Resources(DataClassJSONMixin):
Also refer to the `K8s conventions. <https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#resource-units-in-kubernetes>`__
"""

cpu: Optional[str] = None
mem: Optional[str] = None
gpu: Optional[str] = None
ephemeral_storage: Optional[str] = None
cpu: Optional[Union[str, int, float]] = None
mem: Optional[Union[str, int]] = None
gpu: Optional[Union[str, int]] = None
ephemeral_storage: Optional[Union[str, int]] = None

def __post_init__(self):
def _check_none_or_str(value):
def _check_cpu(value):
if value is None:
return
if not isinstance(value, str):
raise AssertionError(f"{value} should be a string")
if not isinstance(value, (str, int, float)):
raise AssertionError(f"{value} should be of type str or int or float")

_check_none_or_str(self.cpu)
_check_none_or_str(self.mem)
_check_none_or_str(self.gpu)
_check_none_or_str(self.ephemeral_storage)
def _check_others(value):
if value is None:
return
if not isinstance(value, (str, int)):
raise AssertionError(f"{value} should be of type str or int")

_check_cpu(self.cpu)
_check_others(self.mem)
_check_others(self.gpu)
_check_others(self.ephemeral_storage)


@dataclass
Expand All @@ -59,13 +66,15 @@ class ResourceSpec(DataClassJSONMixin):
def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: # type: ignore
resource_entries = []
if resources.cpu is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=resources.cpu))
resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=str(resources.cpu)))
if resources.mem is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=resources.mem))
resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=str(resources.mem)))
if resources.gpu is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=resources.gpu))
resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu)))
if resources.ephemeral_storage is not None:
resource_entries.append(_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=resources.ephemeral_storage))
resource_entries.append(
_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage))
)
return resource_entries


Expand Down
18 changes: 9 additions & 9 deletions flytekit/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import wraps
from hashlib import sha224 as _sha224
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast

from flyteidl.core import tasks_pb2 as _core_task

Expand Down Expand Up @@ -62,14 +62,14 @@ def _get_container_definition(
command: List[str],
args: Optional[List[str]] = None,
data_loading_config: Optional["task_models.DataLoadingConfig"] = None,
ephemeral_storage_request: Optional[str] = None,
cpu_request: Optional[str] = None,
gpu_request: Optional[str] = None,
memory_request: Optional[str] = None,
ephemeral_storage_limit: Optional[str] = None,
cpu_limit: Optional[str] = None,
gpu_limit: Optional[str] = None,
memory_limit: Optional[str] = None,
ephemeral_storage_request: Optional[Union[str, int]] = None,
cpu_request: Optional[Union[str, int, float]] = None,
gpu_request: Optional[Union[str, int]] = None,
memory_request: Optional[Union[str, int]] = None,
ephemeral_storage_limit: Optional[Union[str, int]] = None,
cpu_limit: Optional[Union[str, int, float]] = None,
gpu_limit: Optional[Union[str, int]] = None,
memory_limit: Optional[Union[str, int]] = None,
environment: Optional[Dict[str, str]] = None,
) -> "task_models.Container":
ephemeral_storage_limit = ephemeral_storage_limit
Expand Down
8 changes: 4 additions & 4 deletions tests/flytekit/unit/core/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _

def test_incorrect_type_resources():
with pytest.raises(AssertionError):
Resources(cpu=1) # type: ignore
Resources(cpu=bytes(1)) # type: ignore
with pytest.raises(AssertionError):
Resources(mem=1) # type: ignore
Resources(mem=0.1) # type: ignore
with pytest.raises(AssertionError):
Resources(gpu=1) # type: ignore
Resources(gpu=0.1) # type: ignore
with pytest.raises(AssertionError):
Resources(ephemeral_storage=1) # type: ignore
Resources(ephemeral_storage=0.1) # type: ignore


def test_resources_serialization():
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1983,7 +1983,7 @@ def t1(a: int) -> int:

@workflow
def my_wf(a: int) -> int:
return t1(a=a).with_overrides(requests=Resources(cpu=1)) # type: ignore
return t1(a=a).with_overrides(requests=Resources(cpu=1, mem=1.1)) # type: ignore

with pytest.raises(AssertionError):
my_wf(a=1)
Expand Down
Loading