Skip to content

Commit

Permalink
Merge pull request #304 from populationgenomics/sept-23-upstream-3-1c…
Browse files Browse the repository at this point in the history
…28203

Sept 23 upstream 3 (1c28203)
  • Loading branch information
illusional authored Sep 19, 2023
2 parents ded82fe + e43761b commit 9fc7f9f
Show file tree
Hide file tree
Showing 54 changed files with 792 additions and 360 deletions.
49 changes: 45 additions & 4 deletions auth/auth/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
import random
from typing import Any, Awaitable, Callable, Dict, List
from typing import Any, Awaitable, Callable, Dict, List, Optional

import aiohttp
import kubernetes_asyncio.client
Expand All @@ -16,6 +16,8 @@
from gear.cloud_config import get_gcp_config, get_global_config
from hailtop import aiotools, httpx
from hailtop import batch_client as bc
from hailtop.aiocloud.aioazure import AzureGraphClient
from hailtop.aiocloud.aiogoogle import GoogleIAmClient
from hailtop.utils import secret_alnum_string, time_msecs

log = logging.getLogger('auth.driver')
Expand Down Expand Up @@ -140,10 +142,14 @@ async def delete(self):


class GSAResource:
def __init__(self, iam_client, gsa_email=None):
def __init__(self, iam_client: GoogleIAmClient, gsa_email: Optional[str] = None):
self.iam_client = iam_client
self.gsa_email = gsa_email

async def get_unique_id(self) -> str:
service_account = await self.iam_client.get(f'/serviceAccounts/{self.gsa_email}')
return service_account['uniqueId']

async def create(self, username):
assert self.gsa_email is None

Expand All @@ -164,7 +170,7 @@ async def create(self, username):

async def _delete(self, gsa_email):
try:
await self.iam_client.delete(f'/serviceAccounts/{gsa_email}/keys')
await self.iam_client.delete(f'/serviceAccounts/{gsa_email}')
except aiohttp.ClientResponseError as e:
if e.status == 404:
pass
Expand All @@ -179,10 +185,17 @@ async def delete(self):


class AzureServicePrincipalResource:
def __init__(self, graph_client, app_obj_id=None):
def __init__(self, graph_client: AzureGraphClient, app_obj_id: Optional[str] = None):
self.graph_client = graph_client
self.app_obj_id = app_obj_id

async def get_service_principal_object_id(self) -> str:
assert self.app_obj_id
app = await self.graph_client.get(f'/applications/{self.app_obj_id}')
app_id = app['appId']
service_principal = await self.graph_client.get(f"/servicePrincipals(appId='{app_id}')")
return service_principal['id']

async def create(self, username):
assert self.app_obj_id is None

Expand Down Expand Up @@ -496,6 +509,28 @@ async def delete_user(app, user):
)


async def resolve_identity_uid(app, hail_identity):
id_client = app['identity_client']
db = app['db']

if CLOUD == 'gcp':
gsa = GSAResource(id_client, hail_identity)
hail_identity_uid = await gsa.get_unique_id()
else:
assert CLOUD == 'azure'
sp = AzureServicePrincipalResource(id_client, hail_identity)
hail_identity_uid = await sp.get_service_principal_object_id()

await db.just_execute(
'''
UPDATE users
SET hail_identity_uid = %s
WHERE hail_identity = %s
''',
(hail_identity_uid, hail_identity),
)


async def update_users(app):
log.info('in update_users')

Expand All @@ -511,6 +546,12 @@ async def update_users(app):
for user in deleting_users:
await delete_user(app, user)

users_without_hail_identity_uid = [
x async for x in db.execute_and_fetchall('SELECT * FROM users WHERE hail_identity_uid IS NULL')
]
for user in users_without_hail_identity_uid:
await resolve_identity_uid(app, user['hail_identity'])

return True


Expand Down
1 change: 1 addition & 0 deletions auth/sql/add-hail-identity-uid.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE `users` ADD COLUMN `hail_identity_uid` VARCHAR(300) DEFAULT NULL;
1 change: 1 addition & 0 deletions auth/sql/estimated-current.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ CREATE TABLE `users` (
`tokens_secret_name` varchar(255) DEFAULT NULL,
-- identity
`hail_identity` varchar(255) DEFAULT NULL,
`hail_identity_uid` VARCHAR(255) DEFAULT NULL,
`hail_credentials_secret_name` varchar(255) DEFAULT NULL,
-- namespace, for developers
`namespace_name` varchar(255) DEFAULT NULL,
Expand Down
9 changes: 7 additions & 2 deletions batch/batch/driver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@
from ..file_store import FileStore
from ..globals import HTTP_CLIENT_MAX_SIZE
from ..inst_coll_config import InstanceCollectionConfigs, PoolConfig
from ..utils import authorization_token, batch_only, json_to_value, query_billing_projects
from ..utils import (
authorization_token,
batch_only,
json_to_value,
query_billing_projects_with_cost,
)
from .canceller import Canceller
from .driver import CloudDriver
from .instance_collection import InstanceCollectionManager, JobPrivateInstanceManager, Pool
Expand Down Expand Up @@ -1185,7 +1190,7 @@ async def _cancel_batch(app, batch_id):
async def monitor_billing_limits(app):
db: Database = app['db']

records = await query_billing_projects(db)
records = await query_billing_projects_with_cost(db)
for record in records:
limit = record['limit']
accrued_cost = record['accrued_cost']
Expand Down
24 changes: 18 additions & 6 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@
from ..inst_coll_config import InstanceCollectionConfigs
from ..resource_usage import ResourceUsageMonitor
from ..spec_writer import SpecWriter
from ..utils import query_billing_projects, regions_to_bits_rep, unavailable_if_frozen
from ..utils import (
query_billing_projects_with_cost,
query_billing_projects_without_cost,
regions_to_bits_rep,
unavailable_if_frozen,
)
from .query import CURRENT_QUERY_VERSION, build_batch_jobs_query
from .validate import ValidationError, validate_and_clean_jobs, validate_batch, validate_batch_update

Expand Down Expand Up @@ -2418,9 +2423,16 @@ async def ui_get_billing_limits(request, userdata):
else:
user = None

billing_projects = await query_billing_projects(db, user=user)
billing_projects = await query_billing_projects_with_cost(db, user=user)

open_billing_projects = [bp for bp in billing_projects if bp['status'] == 'open']
closed_billing_projects = [bp for bp in billing_projects if bp['status'] == 'closed']

page_context = {'billing_projects': billing_projects, 'is_developer': userdata['is_developer']}
page_context = {
'open_billing_projects': open_billing_projects,
'closed_billing_projects': closed_billing_projects,
'is_developer': userdata['is_developer'],
}
return await render_template('batch', request, userdata, 'billing_limits.html', page_context)


Expand Down Expand Up @@ -2617,7 +2629,7 @@ async def ui_get_billing(request, userdata):
@catch_ui_error_in_dev
async def ui_get_billing_projects(request, userdata):
db: Database = request.app['db']
billing_projects = await query_billing_projects(db)
billing_projects = await query_billing_projects_without_cost(db)
page_context = {
'billing_projects': [{**p, 'size': len(p['users'])} for p in billing_projects if p['status'] == 'open'],
'closed_projects': [p for p in billing_projects if p['status'] == 'closed'],
Expand All @@ -2635,7 +2647,7 @@ async def get_billing_projects(request, userdata):
else:
user = None

billing_projects = await query_billing_projects(db, user=user)
billing_projects = await query_billing_projects_with_cost(db, user=user)
return json_response(billing_projects)


Expand All @@ -2650,7 +2662,7 @@ async def get_billing_project(request, userdata):
else:
user = None

billing_projects = await query_billing_projects(db, user=user, billing_project=billing_project)
billing_projects = await query_billing_projects_with_cost(db, user=user, billing_project=billing_project)

if not billing_projects:
raise web.HTTPForbidden(reason=f'Unknown Hail Batch billing project {billing_project}.')
Expand Down
42 changes: 40 additions & 2 deletions batch/batch/front_end/templates/billing_limits.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
{% block title %}Billing Limits{% endblock %}
{% block content %}
<h1>Billing Project Limits</h1>
{% if open_billing_projects %}
<div class='flex-col' style="overflow: auto;">
<table class="data-table" id="billing_limits">
<h2>Open Projects</h2>
<table class="data-table" id="open-billing-limits">
<thead>
<tr>
<th>Billing Project</th>
Expand All @@ -12,7 +14,7 @@ <h1>Billing Project Limits</h1>
</tr>
</thead>
<tbody>
{% for row in billing_projects %}
{% for row in open_billing_projects %}
<tr>
<td>{{ row['billing_project'] }}</td>
<td>{{ row['accrued_cost'] }}</td>
Expand All @@ -34,4 +36,40 @@ <h1>Billing Project Limits</h1>
</tbody>
</table>
</div>
{% endif %}
{% if closed_billing_projects %}
<div class='flex-col' style="overflow: auto;">
<h2>Closed Projects</h2>
<table class="data-table" id="closed-billing-limits">
<thead>
<tr>
<th>Billing Project</th>
<th>Accrued Cost</th>
<th>Limit</th>
</tr>
</thead>
<tbody>
{% for row in closed_billing_projects %}
<tr>
<td>{{ row['billing_project'] }}</td>
<td>{{ row['accrued_cost'] }}</td>
{% if is_developer %}
<td>
<form action="{{ base_path }}/billing_limits/{{ row['billing_project'] }}/edit" method="POST">
<input type="hidden" name="_csrf" value="{{ csrf_token }}">
<input type="text" required name="limit" value="{{ row['limit'] }}">
<button>
Edit
</button>
</form>
</td>
{% else %}
<td>{{ row['limit'] }}</td>
{% endif %}
</tr>
{% endfor %}
</tbody>
</table>
</div>
{% endif %}
{% endblock %}
56 changes: 46 additions & 10 deletions batch/batch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,9 @@ def __repr__(self):
return f'global {self._global_counter}'


async def query_billing_projects(db, user=None, billing_project=None):
args = []

async def query_billing_projects_with_cost(db, user=None, billing_project=None):
where_conditions = ["billing_projects.`status` != 'deleted'"]
args = []

if user:
where_conditions.append("JSON_CONTAINS(users, JSON_QUOTE(%s))")
Expand Down Expand Up @@ -161,14 +160,51 @@ async def query_billing_projects(db, user=None, billing_project=None):
LOCK IN SHARE MODE;
'''

def record_to_dict(record):
if record['users'] is None:
record['users'] = []
else:
record['users'] = json.loads(record['users'])
return record
billing_projects = []
async for record in db.execute_and_fetchall(sql, tuple(args)):
record['users'] = json.loads(record['users']) if record['users'] is not None else []
billing_projects.append(record)

return billing_projects


async def query_billing_projects_without_cost(db, user=None, billing_project=None):
where_conditions = ["billing_projects.`status` != 'deleted'"]
args = []

if user:
where_conditions.append("JSON_CONTAINS(users, JSON_QUOTE(%s))")
args.append(user)

if billing_project:
where_conditions.append('billing_projects.name_cs = %s')
args.append(billing_project)

if where_conditions:
where_condition = f'WHERE {" AND ".join(where_conditions)}'
else:
where_condition = ''

sql = f'''
SELECT billing_projects.name as billing_project,
billing_projects.`status` as `status`,
users, `limit`
FROM billing_projects
LEFT JOIN LATERAL (
SELECT billing_project, JSON_ARRAYAGG(`user_cs`) as users
FROM billing_project_users
WHERE billing_project_users.billing_project = billing_projects.name
GROUP BY billing_project_users.billing_project
LOCK IN SHARE MODE
) AS t ON TRUE
{where_condition}
LOCK IN SHARE MODE;
'''

billing_projects = [record_to_dict(record) async for record in db.execute_and_fetchall(sql, tuple(args))]
billing_projects = []
async for record in db.execute_and_fetchall(sql, tuple(args)):
record['users'] = json.loads(record['users']) if record['users'] is not None else []
billing_projects.append(record)

return billing_projects

Expand Down
Loading

0 comments on commit 9fc7f9f

Please sign in to comment.