Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix extract_task_module #1829

Merged
merged 3 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,12 @@
"""

if isinstance(f, TrackedInstance):
if f.instantiated_in:
if hasattr(f, "task_function"):
mod, mod_name, name = _task_module_from_callable(f.task_function)

Check warning on line 302 in flytekit/core/tracker.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/tracker.py#L302

Added line #L302 was not covered by tests
elif f.instantiated_in:
mod = importlib.import_module(f.instantiated_in)
mod_name = mod.__name__
name = f.lhs
elif hasattr(f, "task_function"):
mod, mod_name, name = _task_module_from_callable(f.task_function)
else:
mod, mod_name, name = _task_module_from_callable(f)

Expand Down
36 changes: 36 additions & 0 deletions plugins/flytekit-papermill/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import datetime
import os
import shutil
import tempfile
import typing
from unittest import mock

import pandas as pd
from click.testing import CliRunner
from flytekitplugins.papermill import NotebookTask
from flytekitplugins.pod import Pod
from kubernetes.client import V1Container, V1PodSpec

import flytekit
from flytekit import StructuredDataset, kwtypes, map_task, task, workflow
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.clis.sdk_in_container import pyflyte
from flytekit.configuration import Image, ImageConfig
from flytekit.core import context_manager
from flytekit.remote import FlyteRemote
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, PythonNotebook

Expand Down Expand Up @@ -189,3 +196,32 @@ def wf(a: float) -> typing.List[float]:
return map_task(nb_sub_task)(a=[a, a])

assert wf(a=3.14) == [9.8596, 9.8596]


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_notebook_task(mock_client, mock_remote):
mock_remote._client = mock_client
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
context_manager.FlyteEntities.entities.clear()
notebook_task = """
from flytekitplugins.papermill import NotebookTask

nb_simple = NotebookTask(
name="test",
notebook_path="./core/notebook.ipython",
)
"""
with runner.isolated_filesystem():
os.makedirs("core", exist_ok=True)
with open(os.path.join("core", "notebook.ipython"), "w") as f:
f.write("notebook.ipython")
f.close()
with open(os.path.join("core", "notebook_task.py"), "w") as f:
f.write(notebook_task)
f.close()
result = runner.invoke(pyflyte.main, ["register", "core"])
assert "Successfully registered 2 entities" in result.output
shutil.rmtree("core")
35 changes: 35 additions & 0 deletions plugins/flytekit-sqlalchemy/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
import sqlite3
import tempfile
from typing import Iterator
from unittest import mock

import pandas
import pytest
from click.testing import CliRunner
from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask
from flytekitplugins.sqlalchemy.task import SQLAlchemyTaskExecutor

from flytekit import kwtypes, task, workflow
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.clis.sdk_in_container import pyflyte
from flytekit.core import context_manager
from flytekit.core.context_manager import SecretsManager
from flytekit.models.security import Secret
from flytekit.remote import FlyteRemote
from flytekit.types.schema import FlyteSchema

tk = SQLAlchemyTask(
Expand Down Expand Up @@ -197,3 +203,32 @@ def test_task_serialization_deserialization_with_secret(sql_server):
r = executor.execute_from_model(tt)

assert r.iat[0, 0] == 1


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_sql_task(mock_client, mock_remote):
mock_remote._client = mock_client
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
context_manager.FlyteEntities.entities.clear()
sql_task = """
from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask

tk = SQLAlchemyTask(
"test",
query_template="select * from tracks",
task_config=SQLAlchemyConfig(
uri="sqlite://",
),
)
"""
with runner.isolated_filesystem():
os.makedirs("core", exist_ok=True)
with open(os.path.join("core", "sql_task.py"), "w") as f:
f.write(sql_task)
f.close()
result = runner.invoke(pyflyte.main, ["register", "core"])
assert "Successfully registered 1 entities" in result.output
shutil.rmtree("core")
29 changes: 29 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ def my_workflow(x: int, y: int) -> int:
return sum(x=square(z=x), y=square(z=y))
"""

shell_task = """
from flytekit.extras.tasks.shell import ShellTask

t = ShellTask(
name="test",
script="echo 'Hello World'",
)
"""


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote")
def test_saving_remote(mock_remote):
Expand Down Expand Up @@ -69,6 +78,26 @@ def test_register_with_no_output_dir_passed(mock_client, mock_remote):
shutil.rmtree("core1")


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_shell_task(mock_client, mock_remote):
mock_remote._client = mock_client
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
context_manager.FlyteEntities.entities.clear()
with runner.isolated_filesystem():
out = subprocess.run(["git", "init"], capture_output=True)
assert out.returncode == 0
os.makedirs("core2", exist_ok=True)
with open(os.path.join("core2", "shell_task.py"), "w") as f:
f.write(shell_task)
f.close()
result = runner.invoke(pyflyte.main, ["register", "core2"])
assert "Successfully registered 2 entities" in result.output
shutil.rmtree("core2")


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_non_fast_register(mock_client, mock_remote):
Expand Down
Loading