
Databricks Agent #1797

Merged · 31 commits · Sep 12, 2023
Commits
540eb5d  Merge branch 'master' of https://github.com/Future-Outlier/flytekit (Aug 12, 2023)
984d44a  Merge branch 'master' of https://github.com/Future-Outlier/flytekit (Aug 16, 2023)
d1491f0  databricks agent v1 (Aug 16, 2023)
5e5492d  revision for docker image (Aug 16, 2023)
003fa8b  rerun make lint and make fmt (Aug 17, 2023)
7c52cba  add PERMANENT_FAILURE (Aug 17, 2023)
442be79  REST API for databricks agent v1, async get function is unsure (Aug 18, 2023)
6e51c36  add aiohttp in setup.py (Aug 18, 2023)
b9c98f1  databricks agent with getting token by secret (Aug 20, 2023)
23c5501  revise the code and delete the databricks_token member (Aug 20, 2023)
7a98b19  remove databricks_token member (Aug 22, 2023)
fa2059d  add databricks agent test (Aug 22, 2023)
8d4c77d  Merge branch 'flyteorg:master' into databricks-python-sdk-agent (Future-Outlier, Aug 22, 2023)
6b3e745  revise by kevin (Aug 22, 2023)
a5e0412  Merge branch 'databricks-python-sdk-agent' of https://github.com/Futu… (Aug 22, 2023)
8f9dcda  edit get function (Aug 22, 2023)
bf857ac  add spark plugin_requires in setup.py (Aug 22, 2023)
9c20c4b  Refactor and Revise test_agent.py after kevin's refactor (Aug 22, 2023)
77d2d70  remove databricks endpoint member (Aug 23, 2023)
ab01850  fix databricks test_agent.py args error (Aug 23, 2023)
232d19d  Databricks Agent With Agent Server Only (Aug 30, 2023)
ef2b2f7  Fix dev-requirements.in lint error (Aug 30, 2023)
606c09f  Merge branch 'flyteorg:master' into databricks-python-sdk-agent (Future-Outlier, Aug 30, 2023)
d65525f  error handle (pingsutw, Sep 4, 2023)
550973d  lint (pingsutw, Sep 4, 2023)
008cf7b  nit (pingsutw, Sep 4, 2023)
92e3310  Update from kevin's revision (Sep 4, 2023)
502ae3f  fix the mocked header in test (Sep 4, 2023)
1dbfd38  Merge branch 'master' into databricks-python-sdk-agent (Sep 8, 2023)
5b091b9  update spark agent test (Sep 8, 2023)
f074740  rename token to access_token (Sep 9, 2023)
3 changes: 0 additions & 3 deletions flytekit/extend/backend/agent_service.py
@@ -40,7 +40,6 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext)
                 logger.error(f"failed to run sync create with error {e}")
                 raise
         except Exception as e:
-            logger.error(f"failed to create task with error {e}")
             context.set_code(grpc.StatusCode.INTERNAL)
             context.set_details(f"failed to create task with error {e}")
@@ -60,7 +59,6 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext)
                 logger.error(f"failed to run sync get with error {e}")
                 raise
         except Exception as e:
-            logger.error(f"failed to get task with error {e}")
             context.set_code(grpc.StatusCode.INTERNAL)
             context.set_details(f"failed to get task with error {e}")

@@ -80,6 +78,5 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext)
                 logger.error(f"failed to run sync delete with error {e}")
                 raise
         except Exception as e:
-            logger.error(f"failed to delete task with error {e}")
             context.set_code(grpc.StatusCode.INTERNAL)
             context.set_details(f"failed to delete task with error {e}")
2 changes: 2 additions & 0 deletions plugins/flytekit-spark/dev-requirements.in
@@ -0,0 +1,2 @@
aioresponses
pytest-asyncio
44 changes: 44 additions & 0 deletions plugins/flytekit-spark/dev-requirements.txt
@@ -0,0 +1,44 @@
#
# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
#
#    pip-compile dev-requirements.in
#
aiohttp==3.8.5
    # via aioresponses
aioresponses==0.7.4
    # via -r dev-requirements.in
aiosignal==1.3.1
    # via aiohttp
async-timeout==4.0.3
    # via aiohttp
attrs==23.1.0
    # via aiohttp
charset-normalizer==3.2.0
    # via aiohttp
exceptiongroup==1.1.3
    # via pytest
frozenlist==1.4.0
    # via
    #   aiohttp
    #   aiosignal
idna==3.4
    # via yarl
iniconfig==2.0.0
    # via pytest
multidict==6.0.4
    # via
    #   aiohttp
    #   yarl
packaging==23.1
    # via pytest
pluggy==1.3.0
    # via pytest
pytest==7.4.0
    # via pytest-asyncio
pytest-asyncio==0.21.1
    # via -r dev-requirements.in
tomli==2.0.1
    # via pytest
yarl==1.9.2
    # via aiohttp
1 change: 1 addition & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/__init__.py
@@ -17,6 +17,7 @@
 
 from flytekit.configuration import internal as _internal
 
+from .agent import DatabricksAgent
 from .pyspark_transformers import PySparkPipelineModelTransformer
 from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer  # noqa
 from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler
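The added import is load-bearing: agent.py (the next file) calls AgentRegistry.register(DatabricksAgent()) at module import time, so importing the plugin package is all it takes for the agent to become discoverable. A minimal sketch of that lookup flow, mirroring what the test at the bottom of this PR does:

# Importing the plugin package registers the agent as an import side effect.
import flytekitplugins.spark  # noqa: F401

from flytekit.extend.backend.base_agent import AgentRegistry

agent = AgentRegistry.get_agent("spark")  # returns the registered DatabricksAgent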
99 changes: 99 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/agent.py
@@ -0,0 +1,99 @@
import json
import pickle
import typing
from dataclasses import dataclass
from typing import Optional

import aiohttp
import grpc
from flyteidl.admin.agent_pb2 import PENDING, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource

import flytekit
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class Metadata:
    databricks_instance: str
    run_id: str


class DatabricksAgent(AgentBase):
    def __init__(self):
        super().__init__(task_type="spark")

    async def async_create(
        self,
        context: grpc.ServicerContext,
        output_prefix: str,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
    ) -> CreateTaskResponse:
        custom = task_template.custom
        container = task_template.container
        databricks_job = custom["databricksConf"]
        if not databricks_job["new_cluster"].get("docker_image"):
            databricks_job["new_cluster"]["docker_image"] = {"url": container.image}
        if not databricks_job["new_cluster"].get("spark_conf"):
            databricks_job["new_cluster"]["spark_conf"] = custom["sparkConf"]
        databricks_job["spark_python_task"] = {
            "python_file": custom["mainApplicationFile"],
            "parameters": tuple(container.args),
        }

        databricks_instance = custom["databricksInstance"]
        databricks_url = f"https://{databricks_instance}/api/2.0/jobs/runs/submit"
        data = json.dumps(databricks_job)

        async with aiohttp.ClientSession() as session:
            async with session.post(databricks_url, headers=get_header(), data=data) as resp:
                if resp.status != 200:
                    raise Exception(f"Failed to create databricks job with error: {resp.reason}")
                response = await resp.json()

        metadata = Metadata(
            databricks_instance=databricks_instance,
            run_id=str(response["run_id"]),
        )
        return CreateTaskResponse(resource_meta=pickle.dumps(metadata))

    async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
        metadata = pickle.loads(resource_meta)
        databricks_instance = metadata.databricks_instance
        databricks_url = f"https://{databricks_instance}/api/2.0/jobs/runs/get?run_id={metadata.run_id}"

        async with aiohttp.ClientSession() as session:
            async with session.get(databricks_url, headers=get_header()) as resp:
                if resp.status != 200:
                    raise Exception(f"Failed to get databricks job {metadata.run_id} with error: {resp.reason}")
                response = await resp.json()

        cur_state = PENDING
        if response.get("state") and response["state"].get("result_state"):
            cur_state = convert_to_flyte_state(response["state"]["result_state"])

        return GetTaskResponse(resource=Resource(state=cur_state))

    async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
        metadata = pickle.loads(resource_meta)

        databricks_url = f"https://{metadata.databricks_instance}/api/2.0/jobs/runs/cancel"
        data = json.dumps({"run_id": metadata.run_id})

        async with aiohttp.ClientSession() as session:
            async with session.post(databricks_url, headers=get_header(), data=data) as resp:
                if resp.status != 200:
                    raise Exception(f"Failed to cancel databricks job {metadata.run_id} with error: {resp.reason}")
                await resp.json()

        return DeleteTaskResponse()


def get_header() -> typing.Dict[str, str]:
    token = flytekit.current_context().secrets.get("databricks", "access_token")
    return {"Authorization": f"Bearer {token}", "content-type": "application/json"}


AgentRegistry.register(DatabricksAgent())
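For orientation, async_create consumes the custom payload that the Spark plugin serializes for Databricks tasks. A task that would flow through this agent might look roughly like the sketch below; it is a hedged example that assumes the plugin's Databricks task config maps onto the sparkConf/databricksConf/databricksInstance keys read above, with illustrative values borrowed from the test fixture further down:

# Hedged sketch, not part of this diff: a task whose serialized custom fields
# (sparkConf, databricksConf, databricksInstance) feed DatabricksAgent.async_create.
# The Databricks config class and its field names are assumptions here.
from flytekit import task
from flytekitplugins.spark import Databricks

@task(
    task_config=Databricks(
        spark_conf={"spark.driver.memory": "1000M"},
        databricks_conf={
            "run_name": "flytekit databricks plugin example",
            "new_cluster": {
                "spark_version": "12.2.x-scala2.12",
                "node_type_id": "n2-highmem-4",
                "num_workers": 1,
            },
        },
        databricks_instance="test-account.cloud.databricks.com",
    )
)
def hello_spark() -> int:
    # Executes inside the Databricks job run submitted by the agent.
    return 1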
10 changes: 6 additions & 4 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -111,16 +111,18 @@ def __init__(
         **kwargs,
     ):
         self.sess: Optional[SparkSession] = None
-        self._default_executor_path: Optional[str] = task_config.executor_path
-        self._default_applications_path: Optional[str] = task_config.applications_path
+        self._default_executor_path: str = task_config.executor_path
+        self._default_applications_path: str = task_config.applications_path
 
         if isinstance(container_image, ImageSpec):
             if container_image.base_image is None:
                 img = f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
                 container_image.base_image = img
                 # default executor path and applications path in apache/spark-py:3.3.1
-                self._default_executor_path = "/usr/bin/python3"
-                self._default_applications_path = "local:///usr/local/bin/entrypoint.py"
+                self._default_executor_path = self._default_executor_path or "/usr/bin/python3"
+                self._default_applications_path = (
+                    self._default_applications_path or "local:///usr/local/bin/entrypoint.py"
+                )
         super(PysparkFunctionTask, self).__init__(
             task_config=task_config,
             task_type=self._SPARK_TASK_TYPE,
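The "or" fallback above changes precedence: a user-supplied executor_path or applications_path now survives the ImageSpec branch instead of being overwritten with the apache/spark-py defaults. A toy illustration of the pattern, with hypothetical values:

# Toy illustration of the fallback behavior introduced above.
from typing import Optional

def resolve_executor_path(user_path: Optional[str]) -> str:
    # The user-provided path wins; the spark-py default applies only when unset.
    return user_path or "/usr/bin/python3"

assert resolve_executor_path(None) == "/usr/bin/python3"
assert resolve_executor_path("/opt/conda/bin/python") == "/opt/conda/bin/python"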
2 changes: 1 addition & 1 deletion plugins/flytekit-spark/setup.py
@@ -4,7 +4,7 @@
 
 microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
 
-plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0"]
+plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp"]
 
 __version__ = "0.0.0+develop"
 
136 changes: 136 additions & 0 deletions plugins/flytekit-spark/tests/test_agent.py
@@ -0,0 +1,136 @@
import pickle
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock

import grpc
import pytest
from aioresponses import aioresponses
from flyteidl.admin.agent_pb2 import SUCCEEDED
from flytekitplugins.spark.agent import Metadata, get_header

from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.interfaces.cli_identifiers import Identifier
from flytekit.models import literals, task
from flytekit.models.core.identifier import ResourceType
from flytekit.models.task import Container, Resources, TaskTemplate


@pytest.mark.asyncio
async def test_databricks_agent():
    ctx = MagicMock(spec=grpc.ServicerContext)
    agent = AgentRegistry.get_agent("spark")

    task_id = Identifier(
        resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
    )
    task_metadata = task.TaskMetadata(
        True,
        task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        literals.RetryStrategy(3),
        True,
        "0.1.1b0",
        "This is deprecated!",
        True,
        "A",
    )
    task_config = {
        "sparkConf": {
            "spark.driver.memory": "1000M",
            "spark.executor.memory": "1000M",
            "spark.executor.cores": "1",
            "spark.executor.instances": "2",
            "spark.driver.cores": "1",
        },
        "mainApplicationFile": "dbfs:/entrypoint.py",
        "databricksConf": {
            "run_name": "flytekit databricks plugin example",
            "new_cluster": {
                "spark_version": "12.2.x-scala2.12",
                "node_type_id": "n2-highmem-4",
                "num_workers": 1,
            },
            "timeout_seconds": 3600,
            "max_retries": 1,
        },
        "databricksInstance": "test-account.cloud.databricks.com",
    }
    container = Container(
        image="flyteorg/flytekit:databricks-0.18.0-py3.7",
        command=[],
        args=[
            "pyflyte-fast-execute",
            "--additional-distribution",
            "s3://my-s3-bucket/flytesnacks/development/24UYJEF2HDZQN3SG4VAZSM4PLI======/script_mode.tar.gz",
            "--dest-dir",
            "/root",
            "--",
            "pyflyte-execute",
            "--inputs",
            "s3://my-s3-bucket",
            "--output-prefix",
            "s3://my-s3-bucket",
            "--raw-output-data-prefix",
            "s3://my-s3-bucket",
            "--checkpoint-path",
            "s3://my-s3-bucket",
            "--prev-checkpoint",
            "s3://my-s3-bucket",
            "--resolver",
            "flytekit.core.python_auto_container.default_task_resolver",
            "--",
            "task-module",
            "spark_local_example",
            "task-name",
            "hello_spark",
        ],
        resources=Resources(
            requests=[],
            limits=[],
        ),
        env={},
        config={},
    )

    dummy_template = TaskTemplate(
        id=task_id,
        custom=task_config,
        metadata=task_metadata,
        container=container,
        interface=None,
        type="spark",
    )
    mocked_token = "mocked_databricks_token"
    mocked_context = mock.patch("flytekit.current_context", autospec=True).start()
    mocked_context.return_value.secrets.get.return_value = mocked_token

    metadata_bytes = pickle.dumps(
        Metadata(
            databricks_instance="test-account.cloud.databricks.com",
            run_id="123",
        )
    )

    mock_create_response = {"run_id": "123"}
    mock_get_response = {"run_id": "123", "state": {"result_state": "SUCCESS"}}
    mock_delete_response = {}
    create_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/submit"
    get_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/get?run_id=123"
    delete_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/cancel"
    with aioresponses() as mocked:
        mocked.post(create_url, status=200, payload=mock_create_response)
        res = await agent.async_create(ctx, "/tmp", dummy_template, None)
        assert res.resource_meta == metadata_bytes

        mocked.get(get_url, status=200, payload=mock_get_response)
        res = await agent.async_get(ctx, metadata_bytes)
        assert res.resource.state == SUCCEEDED
        assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl()

        mocked.post(delete_url, status=200, payload=mock_delete_response)
        await agent.async_delete(ctx, metadata_bytes)

    assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"}

    mock.patch.stopall()
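Outside this test, get_header resolves the token through flytekit's secrets machinery rather than a mock, so whatever process runs the agent must be able to resolve the secret group "databricks" with key "access_token". A rough local smoke check, assuming flytekit's default _FSEC_<GROUP>_<KEY> environment-variable convention for secrets and a fake token value:

# Rough sketch, not part of this diff. Assumes flytekit's default secrets
# environment prefix (_FSEC_) and uses a fake token; a deployed agent server
# would mount the real Databricks token as a secret instead.
import os

os.environ["_FSEC_DATABRICKS_ACCESS_TOKEN"] = "dapi-fake-token"

from flytekitplugins.spark.agent import get_header

assert get_header() == {
    "Authorization": "Bearer dapi-fake-token",
    "content-type": "application/json",
}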