Skip to content

Commit

Permalink
Refactor: get_total_cost_for_address + fix View (#466)
Browse files Browse the repository at this point in the history
Solutions:

- Move get_total_cost_for_address from accessors/vm.py to accesors/cost.py
- Add unit test
- Fix view use correct unit in view calculation (MB to MiB)
  • Loading branch information
1yam committed Aug 25, 2023
1 parent 1630ca7 commit 913e007
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def upgrade() -> None:
JOIN files f ON file_pins.file_hash::text = f.hash::text
WHERE file_pins.owner IS NOT NULL
GROUP BY file_pins.owner) storage ON vm_prices.owner::text = storage.owner::text,
LATERAL ( SELECT 3::numeric * storage.storage_size / 1000000::numeric AS total_storage_cost) sc,
LATERAL ( SELECT 3::numeric * storage.storage_size / 1048576::numeric AS total_storage_cost) sc,
LATERAL ( SELECT COALESCE(vm_prices.total_vm_cost, 0::double precision) +
COALESCE(sc.total_storage_cost, 0::numeric)::double precision AS total_cost) tc
"""
Expand Down
15 changes: 15 additions & 0 deletions src/aleph/db/accessors/cost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from decimal import Decimal
from typing import Optional
from sqlalchemy import select, func, text
from aleph.types.db_session import DbSession


def get_total_cost_for_address(session: DbSession, address: str) -> Decimal:
select_stmt = (
select(func.sum(text("total_cost")))
.select_from(text("public.costs_view"))
.where(text("address = :address"))
).params(address=address)

total_cost = session.execute(select_stmt).scalar()
return Decimal(total_cost) if total_cost is not None else Decimal(0)
11 changes: 0 additions & 11 deletions src/aleph/db/accessors/vms.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,3 @@ def refresh_vm_version(session: DbSession, vm_hash: str) -> None:
)
session.execute(delete(VmVersionDb).where(VmVersionDb.vm_hash == vm_hash))
session.execute(upsert_stmt)


def get_total_cost_for_address(session: DbSession, address: str) -> Decimal:
select_stmt = (
select(func.sum(text("total_cost")))
.select_from(text("public.costs_view"))
.where(text("address = :address"))
).params(address=address)

total_cost = session.execute(select_stmt).scalar()
return Decimal(total_cost) if total_cost is not None else Decimal(0)
2 changes: 1 addition & 1 deletion src/aleph/handlers/content/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from decimal import Decimal

from aleph.db.accessors.balances import get_total_balance
from aleph.db.accessors.cost import get_total_cost_for_address
from aleph.db.accessors.files import (
find_file_tags,
find_file_pins,
Expand All @@ -31,7 +32,6 @@
delete_vm_updates,
refresh_vm_version,
is_vm_amend_allowed,
get_total_cost_for_address,
)
from aleph.db.models import (
MessageDb,
Expand Down
210 changes: 210 additions & 0 deletions tests/db/test_cost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from typing import List, Protocol

import pytest
import pytz
from aleph_message.models import (
Chain,
InstanceContent,
MessageType,
ItemType,
ExecutableContent,
ProgramContent,
)
from decimal import Decimal

from aleph_message.models.execution.volume import ImmutableVolume

from aleph.db.accessors.cost import get_total_cost_for_address
from aleph.db.accessors.files import insert_message_file_pin, upsert_file_tag
from aleph.db.models import (
AlephBalanceDb,
PendingMessageDb,
StoredFileDb,
MessageStatusDb,
)
import json
from aleph.toolkit.timestamp import timestamp_to_datetime
from aleph.types.db_session import DbSessionFactory, DbSession
import datetime as dt

from aleph.types.files import FileType, FileTag
from aleph.types.message_status import MessageStatus


class Volume(Protocol):
ref: str
use_latest: bool


def get_volume_refs(content: ExecutableContent) -> List[Volume]:
volumes = []

for volume in content.volumes:
if isinstance(volume, ImmutableVolume):
volumes.append(volume)

if isinstance(content, ProgramContent):
volumes += [content.code, content.runtime]
if content.data:
volumes.append(content.data)

elif isinstance(content, InstanceContent):
if parent := content.rootfs.parent:
volumes.append(parent)

return volumes


def insert_volume_refs(session: DbSession, message: PendingMessageDb):
"""
Insert volume references in the DB to make the program processable.
"""

content = InstanceContent.parse_raw(message.item_content)
volumes = get_volume_refs(content)

created = pytz.utc.localize(dt.datetime(2023, 1, 1))

for volume in volumes:
# Note: we use the reversed ref to generate the file hash for style points,
# but it could be set to any valid hash.
file_hash = volume.ref[::-1]

session.add(StoredFileDb(hash=file_hash, size=1024 * 1024, type=FileType.FILE))
session.flush()
insert_message_file_pin(
session=session,
file_hash=volume.ref[::-1],
owner=content.address,
item_hash=volume.ref,
ref=None,
created=created,
)
upsert_file_tag(
session=session,
tag=FileTag(volume.ref),
owner=content.address,
file_hash=volume.ref[::-1],
last_updated=created,
)


@pytest.fixture
def fixture_instance_message(session_factory: DbSessionFactory) -> PendingMessageDb:
content = {
"address": "0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba",
"allow_amend": False,
"variables": {
"VM_CUSTOM_VARIABLE": "SOMETHING",
"VM_CUSTOM_VARIABLE_2": "32",
},
"environment": {
"reproducible": True,
"internet": False,
"aleph_api": False,
"shared_cache": False,
},
"resources": {"vcpus": 1, "memory": 128, "seconds": 30},
"requirements": {"cpu": {"architecture": "x86_64"}},
"rootfs": {
"parent": {
"ref": "549ec451d9b099cad112d4aaa2c00ac40fb6729a92ff252ff22eef0b5c3cb613",
"use_latest": True,
},
"persistence": "host",
"name": "test-rootfs",
"size_mib": 20 * 1024,
},
"authorized_keys": [
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGULT6A41Msmw2KEu0R9MvUjhuWNAsbdeZ0DOwYbt4Qt user@example",
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH0jqdc5dmt75QhTrWqeHDV9xN8vxbgFyOYs2fuQl7CI",
],
"volumes": [
{
"comment": "Python libraries. Read-only since a 'ref' is specified.",
"mount": "/opt/venv",
"ref": "5f31b0706f59404fad3d0bff97ef89ddf24da4761608ea0646329362c662ba51",
"use_latest": False,
},
{
"comment": "Ephemeral storage, read-write but will not persist after the VM stops",
"mount": "/var/cache",
"ephemeral": True,
"size_mib": 5,
},
{
"comment": "Working data persisted on the VM supervisor, not available on other nodes",
"mount": "/var/lib/sqlite",
"name": "sqlite-data",
"persistence": "host",
"size_mib": 10,
},
{
"comment": "Working data persisted on the Aleph network. "
"New VMs will try to use the latest version of this volume, "
"with no guarantee against conflicts",
"mount": "/var/lib/statistics",
"name": "statistics",
"persistence": "store",
"size_mib": 10,
},
{
"comment": "Raw drive to use by a process, do not mount it",
"name": "raw-data",
"persistence": "host",
"size_mib": 10,
},
],
"time": 1619017773.8950517,
}

pending_message = PendingMessageDb(
item_hash="734a1287a2b7b5be060312ff5b05ad1bcf838950492e3428f2ac6437a1acad26",
type=MessageType.instance,
chain=Chain.ETH,
sender="0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba",
signature=None,
item_type=ItemType.inline,
item_content=json.dumps(content),
time=timestamp_to_datetime(1619017773.8950577),
channel=None,
reception_time=timestamp_to_datetime(1619017774),
fetched=True,
check_message=False,
retries=0,
next_attempt=dt.datetime(2023, 1, 1),
)
with session_factory() as session:
session.add(pending_message)
session.add(
MessageStatusDb(
item_hash=pending_message.item_hash,
status=MessageStatus.PENDING,
reception_time=pending_message.reception_time,
)
)
session.commit()

return pending_message


def test_get_total_cost_for_address(
session_factory: DbSessionFactory, fixture_instance_message
):
with session_factory() as session:
session.add(
AlephBalanceDb(
address="0xB68B9D4f3771c246233823ed1D3Add451055F9Ef",
chain=Chain.ETH,
dapp=None,
balance=Decimal(100_000),
eth_height=0,
)
)
insert_volume_refs(session, fixture_instance_message)
session.commit()

total_cost: Decimal = get_total_cost_for_address(
session=session, address=fixture_instance_message.sender
)
assert total_cost == Decimal(6)

0 comments on commit 913e007

Please sign in to comment.