Skip to content

Commit

Permalink
Merge pull request #12 from csgroup-oss/feat/mlflow-assets
Browse files Browse the repository at this point in the history
Add mlflow registered models to assets
  • Loading branch information
okoko-cs authored Sep 2, 2024
2 parents b3381dc + 6c8f2fc commit b547797
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 4 deletions.
1 change: 1 addition & 0 deletions app/providers/client/gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,7 @@ def _adapt_graphql_project(project_data: GitlabGraphQL_Project) -> Project:
last_commit=last_commit,
files=files,
latest_release=release,
mlflow=None,
access_level=access_level,
)

Expand Down
12 changes: 12 additions & 0 deletions app/providers/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ class Release(BaseModel):
commit: str


class RegisteredModel(BaseModel):
    """A model registered in MLflow, reduced to one selected version."""

    # Name of the registered model as reported by MLflow.
    name: str
    # Version identifier of the model version kept for this entry.
    latest_version: str
    # URI designating that model version (built elsewhere in the project
    # as "models:/<name>/<version>").
    mlflow_uri: str


class MLflow(BaseModel):
    """MLflow information attached to a project."""

    # URI of the MLflow tracking endpoint for the project.
    tracking_uri: str
    # Registered models found for the project; may be empty.
    registered_models: list[RegisteredModel]


class ProjectReference(BaseModel):
id: int
name: str
Expand Down Expand Up @@ -76,6 +87,7 @@ class Project(ProjectPreview):
last_commit: str | None
files: list[str] | None
latest_release: Release | None
mlflow: MLflow | None
# 0 for no access
# 1 for read-only (visitor)
# 2 for modification allowed (contributor)
Expand Down
2 changes: 1 addition & 1 deletion app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@

# ____ MLflow ____ #

MLFLOW_TYPE: Literal["mlflow", "mlflow-sharinghub", "gitlab"] | None = conf(
MLFLOW_TYPE: Literal["mlflow", "mlflow-sharinghub", "gitlab"] = conf(
"mlflow.type", "MLFLOW_TYPE", default="mlflow-sharinghub", cast=str
)
MLFLOW_URL: str | None = conf("mlflow.url", "MLFLOW_URL", cast=str)
Expand Down
20 changes: 19 additions & 1 deletion app/stac/api/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from app.utils import geo
from app.utils import markdown as md
from app.utils.http import is_local, url_for
from app.utils.http import is_local, slugify, url_for

from .category import Category, FeatureVal
from .search import STACPagination
Expand Down Expand Up @@ -534,6 +534,24 @@ def build_stac_item(
},
)

if project.mlflow:
stac_links.append(
{
"rel": "mlflow",
"title": "Tracking URI",
"href": project.mlflow.tracking_uri,
},
)
for rm in project.mlflow.registered_models:
model_name = rm.name.removesuffix(f"({project.id})").rstrip()
rm_asset = {
"href": rm.mlflow_uri,
"title": f"{model_name} v{rm.latest_version}",
"description": rm.mlflow_uri,
"roles": ["mlflow"],
}
stac_assets[f"model:/{slugify(model_name)}"] = rm_asset

if doi := extensions_properties.get("sci:doi"):
stac_links.append({"rel": "cite-as", "href": f"{DOI_URL}{doi}"})

Expand Down
72 changes: 70 additions & 2 deletions app/stac/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,25 @@
import json
import logging
import time
from typing import Literal

from fastapi import HTTPException, Request
from fastapi.routing import APIRouter

from app.auth import GitlabTokenDep
from app.providers.client import CursorPagination, GitlabClient
from app.providers.schemas import Project
from app.settings import ENABLE_CACHE, GITLAB_URL
from app.providers.schemas import MLflow, Project, RegisteredModel
from app.settings import ENABLE_CACHE, GITLAB_URL, MLFLOW_TYPE, MLFLOW_URL
from app.stac.api.category import (
Category,
CategoryFromCollectionIdDep,
FeatureVal,
get_categories,
get_category,
)
from app.utils import geo
from app.utils.cache import cache
from app.utils.http import AiohttpClient, clean_url

from .api.build import (
build_features_collection,
Expand Down Expand Up @@ -190,6 +193,12 @@ async def stac_collection_feature(
return cached_stac["stac"]

await _resolve_license(project, gitlab_client)
await _collect_registered_models(
project,
mlflow_type=MLFLOW_TYPE,
mlflow_url=MLFLOW_URL,
auth_token=token.value,
)

project_stac = build_stac_item(
project=project,
Expand Down Expand Up @@ -401,6 +410,15 @@ async def _stac_search( # noqa: C901
count = len(projects)
await asyncio.gather(
*(_resolve_license(p, gitlab_client) for p in projects),
*(
_collect_registered_models(
p,
mlflow_type=MLFLOW_TYPE,
mlflow_url=MLFLOW_URL,
auth_token=token.value,
)
for p in projects
),
)
features = [
build_stac_item(p, category, request=request, token=token)
Expand Down Expand Up @@ -455,3 +473,53 @@ async def _resolve_license(project: Project, client: GitlabClient) -> None:

if license_ and license_ != nolicense:
project.license = license_


async def _collect_registered_models(
    project: Project,
    mlflow_type: Literal["mlflow", "mlflow-sharinghub", "gitlab"],
    mlflow_url: str | None,
    auth_token: str,
) -> None:
    """Populate ``project.mlflow`` with its tracking URI and registered models.

    The project is left untouched when no MLflow URL is configured, when
    none of the project's categories has the "mlflow" feature enabled, or
    when ``mlflow_type`` is not a type this function knows how to query.

    Args:
        project: Project to enrich; mutated in place.
        mlflow_type: Kind of MLflow deployment configured for the server.
        mlflow_url: Base URL of the MLflow deployment, or ``None`` if unset.
        auth_token: Token forwarded as a bearer token to the MLflow API.
    """
    if not mlflow_url:
        return

    mlflow_enabled = any(
        c.features.get("mlflow") == FeatureVal.ENABLE for c in project.categories
    )
    if not mlflow_enabled:
        return

    # Only the "mlflow-sharinghub" deployment is handled here. Returning
    # early for the other types keeps them an explicit no-op instead of
    # reaching the assignment below with ``tracking_uri`` unbound.
    if mlflow_type != "mlflow-sharinghub":
        return

    # NOTE(review): the concatenation assumes clean_url() returns a base
    # URL that can be directly followed by the project path - confirm.
    tracking_uri = f"{clean_url(mlflow_url)}{project.path}/tracking"
    mlflow_api_url = tracking_uri + "/api/2.0/mlflow"
    registered_models = await _get_registered_models(
        mlflow_api_url, auth_token=auth_token
    )
    project.mlflow = MLflow(
        tracking_uri=tracking_uri,
        registered_models=registered_models,
    )


async def _get_registered_models(
    mlflow_api_url: str, auth_token: str
) -> list[RegisteredModel]:
    """Query an MLflow REST API for its registered models.

    Sends a GET request to ``<mlflow_api_url>/registered-models/search``
    with ``auth_token`` as a bearer token. Each returned model that has at
    least one entry in ``latest_versions`` yields one ``RegisteredModel``
    for its highest-numbered version.

    Args:
        mlflow_api_url: Base URL of the MLflow REST API.
        auth_token: Token sent in the ``Authorization`` header.

    Returns:
        The registered models found. The list is empty when the response
        is not OK or when no model has a version.
    """
    req_url = mlflow_api_url + "/registered-models/search"
    headers = {"Authorization": f"Bearer {auth_token}"}
    async with AiohttpClient() as client:
        response = await client.get(req_url, headers=headers)
        if not response.ok:
            # Best effort: an unreachable or refusing MLflow yields no models.
            return []
        data = await response.json()

    # NOTE(review): only the first page of results is read; a
    # "next_page_token" in the response, if any, is not followed.
    registered_models = []
    for rm_data in data.get("registered_models", []):
        latest_versions = rm_data.get("latest_versions")
        if not latest_versions:
            continue
        # "latest_versions" can hold several entries (one per stage), so
        # the first one is not necessarily the most recent version.
        latest_version = max(latest_versions, key=_model_version_number)
        model_name = rm_data["name"]
        model_version = latest_version["version"]
        registered_models.append(
            RegisteredModel(
                name=model_name,
                latest_version=model_version,
                mlflow_uri=f"models:/{model_name}/{model_version}",
            )
        )
    return registered_models


def _model_version_number(version_data: dict) -> int:
    """Return the numeric version of a model-version entry, or -1 if unusable."""
    try:
        return int(version_data["version"])
    except (KeyError, TypeError, ValueError):
        return -1

0 comments on commit b547797

Please sign in to comment.