-
Notifications
You must be signed in to change notification settings - Fork 260
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
--------- Signed-off-by: Kevin Su <pingsutw@apache.org>
- Loading branch information
Showing
11 changed files
with
185 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base_sensor import BaseSensor | ||
from .file_sensor import FileSensor | ||
from .sensor_engine import SensorEngine |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import collections | ||
import inspect | ||
from abc import abstractmethod | ||
from typing import Any, Dict, Optional, TypeVar | ||
|
||
import jsonpickle | ||
from typing_extensions import get_type_hints | ||
|
||
from flytekit.configuration import SerializationSettings | ||
from flytekit.core.base_task import PythonTask | ||
from flytekit.core.interface import Interface | ||
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin | ||
|
||
T = TypeVar("T") | ||
SENSOR_MODULE = "sensor_module" | ||
SENSOR_NAME = "sensor_name" | ||
SENSOR_CONFIG_PKL = "sensor_config_pkl" | ||
INPUTS = "inputs" | ||
|
||
|
||
class BaseSensor(AsyncAgentExecutorMixin, PythonTask): | ||
""" | ||
Base class for all sensors. Sensors are tasks that are designed to run forever, and periodically check for some | ||
condition to be met. When the condition is met, the sensor will complete. Sensors are designed to be run by the | ||
sensor agent, and not by the Flyte engine. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
sensor_config: Optional[T] = None, | ||
task_type: str = "sensor", | ||
**kwargs, | ||
): | ||
type_hints = get_type_hints(self.poke, include_extras=True) | ||
signature = inspect.signature(self.poke) | ||
inputs = collections.OrderedDict() | ||
for k, v in signature.parameters.items(): # type: ignore | ||
annotation = type_hints.get(k, None) | ||
inputs[k] = annotation | ||
|
||
super().__init__( | ||
task_type=task_type, | ||
name=name, | ||
task_config=None, | ||
interface=Interface(inputs=inputs), | ||
**kwargs, | ||
) | ||
self._sensor_config = sensor_config | ||
|
||
@abstractmethod | ||
async def poke(self, **kwargs) -> bool: | ||
""" | ||
This method should be overridden by the user to implement the actual sensor logic. This method should return | ||
``True`` if the sensor condition is met, else ``False``. | ||
""" | ||
raise NotImplementedError | ||
|
||
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: | ||
cfg = { | ||
SENSOR_MODULE: type(self).__module__, | ||
SENSOR_NAME: type(self).__name__, | ||
} | ||
if self._sensor_config is not None: | ||
cfg[SENSOR_CONFIG_PKL] = jsonpickle.encode(self._sensor_config) | ||
return cfg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from typing import Optional, TypeVar | ||
|
||
from flytekit import FlyteContextManager | ||
from flytekit.sensor.base_sensor import BaseSensor | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class FileSensor(BaseSensor): | ||
def __init__(self, name: str, config: Optional[T] = None, **kwargs): | ||
super().__init__(name=name, sensor_config=config, **kwargs) | ||
|
||
async def poke(self, path: str) -> bool: | ||
file_access = FlyteContextManager.current_context().file_access | ||
fs = file_access.get_filesystem_for_path(path, asynchronous=True) | ||
if file_access.is_remote(path): | ||
return await fs._exists(path) | ||
return fs.exists(path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import importlib | ||
import typing | ||
from typing import Optional | ||
|
||
import cloudpickle | ||
import grpc | ||
import jsonpickle | ||
from flyteidl.admin.agent_pb2 import ( | ||
RUNNING, | ||
SUCCEEDED, | ||
CreateTaskResponse, | ||
DeleteTaskResponse, | ||
GetTaskResponse, | ||
Resource, | ||
) | ||
|
||
from flytekit import FlyteContextManager | ||
from flytekit.core.type_engine import TypeEngine | ||
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry | ||
from flytekit.models.literals import LiteralMap | ||
from flytekit.models.task import TaskTemplate | ||
from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, SENSOR_NAME | ||
|
||
T = typing.TypeVar("T") | ||
|
||
|
||
class SensorEngine(AgentBase): | ||
def __init__(self): | ||
super().__init__(task_type="sensor", asynchronous=True) | ||
|
||
async def async_create( | ||
self, | ||
context: grpc.ServicerContext, | ||
output_prefix: str, | ||
task_template: TaskTemplate, | ||
inputs: Optional[LiteralMap] = None, | ||
) -> CreateTaskResponse: | ||
python_interface_inputs = { | ||
name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() | ||
} | ||
ctx = FlyteContextManager.current_context() | ||
if inputs: | ||
native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) | ||
task_template.custom[INPUTS] = native_inputs | ||
return CreateTaskResponse(resource_meta=cloudpickle.dumps(task_template.custom)) | ||
|
||
async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: | ||
meta = cloudpickle.loads(resource_meta) | ||
|
||
sensor_module = importlib.import_module(name=meta[SENSOR_MODULE]) | ||
sensor_def = getattr(sensor_module, meta[SENSOR_NAME]) | ||
sensor_config = jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None | ||
|
||
inputs = meta.get(INPUTS, {}) | ||
cur_state = SUCCEEDED if await sensor_def("sensor", config=sensor_config).poke(**inputs) else RUNNING | ||
return GetTaskResponse(resource=Resource(state=cur_state, outputs=None)) | ||
|
||
async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: | ||
return DeleteTaskResponse() | ||
|
||
|
||
AgentRegistry.register(SensorEngine()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import tempfile | ||
|
||
from flytekit import task, workflow | ||
from flytekit.configuration import ImageConfig, SerializationSettings | ||
from flytekit.sensor.file_sensor import FileSensor | ||
from tests.flytekit.unit.test_translator import default_img | ||
|
||
|
||
def test_sensor_task(): | ||
sensor = FileSensor(name="test_sensor") | ||
assert sensor.task_type == "sensor" | ||
settings = SerializationSettings( | ||
project="project", | ||
domain="domain", | ||
version="version", | ||
env={"FOO": "baz"}, | ||
image_config=ImageConfig(default_image=default_img, images=[default_img]), | ||
) | ||
assert sensor.get_custom(settings) == {"sensor_module": "flytekit.sensor.file_sensor", "sensor_name": "FileSensor"} | ||
tmp_file = tempfile.NamedTemporaryFile() | ||
|
||
@task() | ||
def t1(): | ||
print("flyte") | ||
|
||
@workflow | ||
def wf(): | ||
sensor(tmp_file.name) >> t1() | ||
|
||
if __name__ == "__main__": | ||
wf() |