Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mlflow registered models to assets #12

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/providers/client/gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,7 @@ def _adapt_graphql_project(project_data: GitlabGraphQL_Project) -> Project:
last_commit=last_commit,
files=files,
latest_release=release,
mlflow=None,
access_level=access_level,
)

Expand Down
12 changes: 12 additions & 0 deletions app/providers/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ class Release(BaseModel):
commit: str


class RegisteredModel(BaseModel):
name: str
latest_version: str
mlflow_uri: str


class MLflow(BaseModel):
tracking_uri: str
registered_models: list[RegisteredModel]


class ProjectReference(BaseModel):
id: int
name: str
Expand Down Expand Up @@ -76,6 +87,7 @@ class Project(ProjectPreview):
last_commit: str | None
files: list[str] | None
latest_release: Release | None
mlflow: MLflow | None
# 0 for no access
# 1 for read-only (visitor)
# 2 for modification allowed (contributor)
Expand Down
2 changes: 1 addition & 1 deletion app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@

# ____ MLflow ____ #

MLFLOW_TYPE: Literal["mlflow", "mlflow-sharinghub", "gitlab"] | None = conf(
MLFLOW_TYPE: Literal["mlflow", "mlflow-sharinghub", "gitlab"] = conf(
"mlflow.type", "MLFLOW_TYPE", default="mlflow-sharinghub", cast=str
)
MLFLOW_URL: str | None = conf("mlflow.url", "MLFLOW_URL", cast=str)
Expand Down
20 changes: 19 additions & 1 deletion app/stac/api/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from app.utils import geo
from app.utils import markdown as md
from app.utils.http import is_local, url_for
from app.utils.http import is_local, slugify, url_for

from .category import Category, FeatureVal
from .search import STACPagination
Expand Down Expand Up @@ -534,6 +534,24 @@ def build_stac_item(
},
)

if project.mlflow:
stac_links.append(
{
"rel": "mlflow",
"title": "Tracking URI",
"href": project.mlflow.tracking_uri,
},
)
for rm in project.mlflow.registered_models:
model_name = rm.name.removesuffix(f"({project.id})").rstrip()
rm_asset = {
"href": rm.mlflow_uri,
"title": f"{model_name} v{rm.latest_version}",
"description": rm.mlflow_uri,
"roles": ["mlflow"],
}
stac_assets[f"model:/{slugify(model_name)}"] = rm_asset

if doi := extensions_properties.get("sci:doi"):
stac_links.append({"rel": "cite-as", "href": f"{DOI_URL}{doi}"})

Expand Down
72 changes: 70 additions & 2 deletions app/stac/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,25 @@
import json
import logging
import time
from typing import Literal

from fastapi import HTTPException, Request
from fastapi.routing import APIRouter

from app.auth import GitlabTokenDep
from app.providers.client import CursorPagination, GitlabClient
from app.providers.schemas import Project
from app.settings import ENABLE_CACHE, GITLAB_URL
from app.providers.schemas import MLflow, Project, RegisteredModel
from app.settings import ENABLE_CACHE, GITLAB_URL, MLFLOW_TYPE, MLFLOW_URL
from app.stac.api.category import (
Category,
CategoryFromCollectionIdDep,
FeatureVal,
get_categories,
get_category,
)
from app.utils import geo
from app.utils.cache import cache
from app.utils.http import AiohttpClient, clean_url

from .api.build import (
build_features_collection,
Expand Down Expand Up @@ -190,6 +193,12 @@ async def stac_collection_feature(
return cached_stac["stac"]

await _resolve_license(project, gitlab_client)
await _collect_registered_models(
project,
mlflow_type=MLFLOW_TYPE,
mlflow_url=MLFLOW_URL,
auth_token=token.value,
)

project_stac = build_stac_item(
project=project,
Expand Down Expand Up @@ -401,6 +410,15 @@ async def _stac_search( # noqa: C901
count = len(projects)
await asyncio.gather(
*(_resolve_license(p, gitlab_client) for p in projects),
*(
_collect_registered_models(
p,
mlflow_type=MLFLOW_TYPE,
mlflow_url=MLFLOW_URL,
auth_token=token.value,
)
for p in projects
),
)
features = [
build_stac_item(p, category, request=request, token=token)
Expand Down Expand Up @@ -455,3 +473,53 @@ async def _resolve_license(project: Project, client: GitlabClient) -> None:

if license_ and license_ != nolicense:
project.license = license_


async def _collect_registered_models(
project: Project,
mlflow_type: Literal["mlflow", "mlflow-sharinghub", "gitlab"],
mlflow_url: str | None,
auth_token: str,
) -> None:
if mlflow_url and any(
c.features.get("mlflow") == FeatureVal.ENABLE for c in project.categories
):
mlflow_url = clean_url(mlflow_url)
registered_models = []
match mlflow_type:
case "mlflow-sharinghub":
tracking_uri = f"{mlflow_url}{project.path}/tracking"
mlflow_api_url = tracking_uri + "/api/2.0/mlflow"
registered_models.extend(
await _get_registered_models(mlflow_api_url, auth_token=auth_token)
)
project.mlflow = MLflow(
tracking_uri=tracking_uri,
registered_models=registered_models,
)


async def _get_registered_models(
mlflow_api_url: str, auth_token: str
) -> list[RegisteredModel]:
registered_models = []
req_url = mlflow_api_url + "/registered-models/search"
async with AiohttpClient() as client:
headers = {"Authorization": f"Bearer {auth_token}"}
response = await client.get(req_url, headers=headers)
if response.ok:
data = await response.json()
all_registered_models = data.get("registered_models", [])
for rm_data in all_registered_models:
if latest_versions := rm_data.get("latest_versions", []):
latest_version = latest_versions[0]
model_name = rm_data["name"]
model_version = latest_version["version"]
registered_models.append(
RegisteredModel(
name=model_name,
latest_version=model_version,
mlflow_uri=f"models:/{model_name}/{model_version}",
)
)
return registered_models
Loading