Skip to content

Commit

Permalink
Async file sensor (#1790)
Browse files Browse the repository at this point in the history
---------
Signed-off-by: Kevin Su <pingsutw@apache.org>
  • Loading branch information
pingsutw committed Aug 23, 2023
1 parent 2164d4e commit 0a772f3
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 3 deletions.
1 change: 1 addition & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@
from flytekit.models.documentation import Description, Documentation, SourceCode
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.sensor.sensor_engine import SensorEngine
from flytekit.types import directory, file, iterator
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
Expand Down
2 changes: 1 addition & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def execute(self, **kwargs) -> typing.Any:
agent = AgentRegistry.get_agent(dummy_context, cp_entity.template.type)

if agent is None:
raise Exception("Cannot run the task locally, please mock.")
raise Exception("Cannot find the agent for the task")
literals = {}
ctx = FlyteContext.current_context()
for k, v in kwargs.items():
Expand Down
3 changes: 3 additions & 0 deletions flytekit/sensor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base_sensor import BaseSensor
from .file_sensor import FileSensor
from .sensor_engine import SensorEngine
66 changes: 66 additions & 0 deletions flytekit/sensor/base_sensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import collections
import inspect
from abc import abstractmethod
from typing import Any, Dict, Optional, TypeVar

import jsonpickle
from typing_extensions import get_type_hints

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.interface import Interface
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin

# Type variable used to annotate the optional, user-supplied sensor config object.
T = TypeVar("T")
# Keys of the custom dict built by BaseSensor.get_custom().
SENSOR_MODULE = "sensor_module"  # module that defines the sensor class
SENSOR_NAME = "sensor_name"  # name of the sensor class
SENSOR_CONFIG_PKL = "sensor_config_pkl"  # jsonpickle-encoded sensor config; only set when a config is given
# Key under which SensorEngine.async_create stores the native (Python) task inputs.
INPUTS = "inputs"


class BaseSensor(AsyncAgentExecutorMixin, PythonTask):
    """
    Common parent of every sensor task.

    A sensor is a task that keeps checking for some condition and completes once that condition is met.
    Sensors are designed to be run by the sensor agent, and not by the Flyte engine.
    """

    def __init__(
        self,
        name: str,
        sensor_config: Optional[T] = None,
        task_type: str = "sensor",
        **kwargs,
    ):
        """
        Build the task interface from the signature of ``poke`` and remember the sensor config.

        :param name: task name forwarded to the parent constructor.
        :param sensor_config: optional config object; serialized by ``get_custom`` when not ``None``.
        :param task_type: task type forwarded to the parent constructor.
        :param kwargs: forwarded unchanged to the parent constructor.
        """
        # One task input per parameter of the bound ``poke`` method, typed by its resolved annotation
        # (``None`` when the parameter carries no annotation).
        hints = get_type_hints(self.poke, include_extras=True)
        poke_params = inspect.signature(self.poke).parameters
        poke_inputs = collections.OrderedDict((param, hints.get(param)) for param in poke_params)

        super().__init__(
            task_type=task_type,
            name=name,
            task_config=None,
            interface=Interface(inputs=poke_inputs),
            **kwargs,
        )
        self._sensor_config = sensor_config

    @abstractmethod
    async def poke(self, **kwargs) -> bool:
        """
        Sensor logic, to be supplied by the subclass.

        :return: ``True`` if the sensor condition is met, else ``False``.
        """
        raise NotImplementedError

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """
        Describe this sensor as a plain dict: its class's module and name, plus the jsonpickle-encoded
        config when one was supplied.

        :param settings: serialization settings; not read by this implementation.
        """
        sensor_cls = type(self)
        custom = {SENSOR_MODULE: sensor_cls.__module__, SENSOR_NAME: sensor_cls.__name__}
        if self._sensor_config is None:
            return custom
        custom[SENSOR_CONFIG_PKL] = jsonpickle.encode(self._sensor_config)
        return custom
18 changes: 18 additions & 0 deletions flytekit/sensor/file_sensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Optional, TypeVar

from flytekit import FlyteContextManager
from flytekit.sensor.base_sensor import BaseSensor

# Type variable used to annotate FileSensor's optional config object.
T = TypeVar("T")


class FileSensor(BaseSensor):
    """Sensor whose condition is that something exists at a given path."""

    def __init__(self, name: str, config: Optional[T] = None, **kwargs):
        """
        :param name: task name.
        :param config: optional config object, handed to ``BaseSensor`` as ``sensor_config``.
        :param kwargs: forwarded unchanged to ``BaseSensor``.
        """
        super().__init__(name=name, sensor_config=config, **kwargs)

    async def poke(self, path: str) -> bool:
        """
        Check once whether ``path`` exists.

        :param path: path to look for, resolved through the current context's file access layer.
        :return: the result of the filesystem's existence check for ``path``.
        """
        accessor = FlyteContextManager.current_context().file_access
        filesystem = accessor.get_filesystem_for_path(path, asynchronous=True)
        if not accessor.is_remote(path):
            # Non-remote paths use the plain synchronous call.
            return filesystem.exists(path)
        # Remote paths are checked through the filesystem's awaitable ``_exists``.
        return await filesystem._exists(path)
62 changes: 62 additions & 0 deletions flytekit/sensor/sensor_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import importlib
import typing
from typing import Optional

import cloudpickle
import grpc
import jsonpickle
from flyteidl.admin.agent_pb2 import (
RUNNING,
SUCCEEDED,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
Resource,
)

from flytekit import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, SENSOR_NAME

# Generic type variable; nothing in the visible part of this module references it.
T = typing.TypeVar("T")


class SensorEngine(AgentBase):
    """
    Agent for tasks of type ``"sensor"``.

    ``async_create`` packs the task's custom dict (plus the native inputs, if any) into ``resource_meta``;
    ``async_get`` unpacks it, re-creates the sensor and pokes it once.
    """

    def __init__(self):
        super().__init__(task_type="sensor", asynchronous=True)

    async def async_create(
        self,
        context: grpc.ServicerContext,
        output_prefix: str,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
    ) -> CreateTaskResponse:
        """
        Serialize everything ``async_get`` needs into the response's ``resource_meta``.

        :param context: gRPC servicer context; not used here.
        :param output_prefix: not used here.
        :param task_template: template whose ``custom`` dict describes the sensor.
        :param inputs: optional literal inputs; when given they are converted to native Python values and
            stored under the ``INPUTS`` key.
        :return: response carrying the cloudpickled custom dict.
        """
        # Work on a copy so attaching the inputs does not modify the caller's task template.
        custom = dict(task_template.custom)
        if inputs:
            python_interface_inputs = {
                name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items()
            }
            ctx = FlyteContextManager.current_context()
            custom[INPUTS] = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)
        return CreateTaskResponse(resource_meta=cloudpickle.dumps(custom))

    async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
        """
        Re-create the sensor described by ``resource_meta`` and poke it once.

        :param context: gRPC servicer context; not used here.
        :param resource_meta: bytes produced by ``async_create``.
        :return: response whose state is ``SUCCEEDED`` when ``poke`` returns a truthy value, else ``RUNNING``.
        """
        # SECURITY(review): unpickling executes arbitrary code if ``resource_meta`` can be forged. This is
        # only safe when the bytes are guaranteed to originate from ``async_create`` -- confirm.
        meta = cloudpickle.loads(resource_meta)

        sensor_module = importlib.import_module(name=meta[SENSOR_MODULE])
        sensor_def = getattr(sensor_module, meta[SENSOR_NAME])
        # SECURITY(review): jsonpickle.decode can also instantiate arbitrary objects; same trust assumption.
        sensor_config = jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None

        inputs = meta.get(INPUTS, {})
        # NOTE(review): the config is passed as ``config=``, which FileSensor accepts, while
        # BaseSensor.__init__ names the parameter ``sensor_config`` -- sensor classes presumably must
        # accept a ``config`` keyword; confirm.
        sensor = sensor_def("sensor", config=sensor_config)
        cur_state = SUCCEEDED if await sensor.poke(**inputs) else RUNNING
        return GetTaskResponse(resource=Resource(state=cur_state, outputs=None))

    async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
        """Return an empty ``DeleteTaskResponse``; no other work is done here."""
        return DeleteTaskResponse()


# Import-time side effect: register one SensorEngine instance with AgentRegistry.
AgentRegistry.register(SensorEngine())
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Metadata:

class BigQueryAgent(AgentBase):
def __init__(self):
super().__init__(task_type="bigquery_query_job_task")
super().__init__(task_type="bigquery_query_job_task", asynchronous=False)

def create(
self,
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"kubernetes>=12.0.1",
"rich",
"rich_click",
"jsonpickle",
],
extras_require=extras_require,
scripts=[
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, **kwargs):
t.execute()

t._task_type = "non-exist-type"
with pytest.raises(Exception, match="Cannot run the task locally"):
with pytest.raises(Exception, match="Cannot find the agent for the task"):
t.execute()


Expand Down
Empty file.
31 changes: 31 additions & 0 deletions tests/flytekit/unit/sensor/test_file_sensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import tempfile

from flytekit import task, workflow
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.sensor.file_sensor import FileSensor
from tests.flytekit.unit.test_translator import default_img


def test_sensor_task():
    """Check FileSensor's task type and serialized custom dict, and that it can be used inside a workflow."""
    sensor = FileSensor(name="test_sensor")
    assert sensor.task_type == "sensor"
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert sensor.get_custom(settings) == {"sensor_module": "flytekit.sensor.file_sensor", "sensor_name": "FileSensor"}

    # The context manager closes (and thereby removes) the temp file when the test finishes instead of
    # leaving the handle open until garbage collection.
    with tempfile.NamedTemporaryFile() as tmp_file:

        @task()
        def t1():
            print("flyte")

        @workflow
        def wf():
            sensor(tmp_file.name) >> t1()

        # When this module is imported by a test runner, ``__name__`` is not "__main__", so the workflow
        # is only executed when the file is run directly as a script.
        if __name__ == "__main__":
            wf()

0 comments on commit 0a772f3

Please sign in to comment.