Skip to content

Commit

Permalink
Python: Reduce memory usage when restoring snapshot
Browse files Browse the repository at this point in the history
Previously we paid for two copies of the snapshot memory: one copy in the wasm
linear memory itself and a second copy in `BUNDLE_MEMORY_SNAPSHOT`. This ensures
that we never have more memory than one copy of the linear memory heap by
copying the memory directly from the snapshot to the linear memory. We also
release the C++ memory when we are done with it.
  • Loading branch information
hoodmane committed Mar 22, 2024
1 parent fac7b76 commit 84687f2
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 67 deletions.
7 changes: 6 additions & 1 deletion src/pyodide/internal/metadata.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import { default as MetadataReader } from "pyodide-internal:runtime-generated/metadata";
export { default as LOCKFILE } from "pyodide-internal:generated/pyodide-lock.json";
import { default as PYODIDE_BUCKET } from "pyodide-internal:generated/pyodide-bucket.json";
import { default as ArtifactBundler } from "pyodide-internal:artifacts";

export const IS_WORKERD = MetadataReader.isWorkerd();
export const IS_TRACING = MetadataReader.isTracing();
export const WORKERD_INDEX_URL = PYODIDE_BUCKET.PYODIDE_PACKAGE_BUCKET_URL;
export const REQUIREMENTS = MetadataReader.getRequirements();
export const MAIN_MODULE_NAME = MetadataReader.getMainModule();
export const BUNDLE_MEMORY_SNAPSHOT = MetadataReader.getMemorySnapshot();
export const MEMORY_SNAPSHOT_READER = MetadataReader.hasMemorySnapshot()
? MetadataReader
: ArtifactBundler.hasMemorySnapshot()
? ArtifactBundler
: undefined;
61 changes: 32 additions & 29 deletions src/pyodide/internal/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
} from "pyodide-internal:setupPackages";
import { default as TarReader } from "pyodide-internal:packages_tar_reader";
import processScriptImports from "pyodide-internal:process_script_imports.py";
import { BUNDLE_MEMORY_SNAPSHOT } from "pyodide-internal:metadata";
import { MEMORY_SNAPSHOT_READER } from "pyodide-internal:metadata";

/**
* This file is a simplified version of the Pyodide loader:
Expand Down Expand Up @@ -38,7 +38,8 @@ import pyodideWasmModule from "pyodide-internal:generated/pyodide.asm.wasm";
*/
import stdlib from "pyodide-internal:generated/python_stdlib.zip";

const SHOULD_UPLOAD_SNAPSHOT = ArtifactBundler.isEnabled() || ArtifactBundler.isEwValidating();
const SHOULD_UPLOAD_SNAPSHOT =
ArtifactBundler.isEnabled() || ArtifactBundler.isEwValidating();
const DEDICATED_SNAPSHOT = true;

/**
Expand All @@ -47,7 +48,9 @@ const DEDICATED_SNAPSHOT = true;
* which is quite slow. Startup with snapshot is 3-5 times faster than without
* it.
*/
let MEMORY = undefined;
let READ_MEMORY = undefined;
let SNAPSHOT_SIZE = undefined;

/**
* Record the dlopen handles that are needed by the MEMORY.
*/
Expand Down Expand Up @@ -234,7 +237,10 @@ function getEmscriptenSettings(lockfile, indexURL) {
// important because the file system lives outside of linear memory.
preRun: [prepareFileSystem, setEnv, preloadDynamicLibs],
instantiateWasm,
noInitialRun: !!MEMORY, // skip running main() if we have a snapshot
// if SNAPSHOT_SIZE is defined, start with the linear memory big enough to
// fit the snapshot. If it's not defined, this falls back to the default.
INITIAL_MEMORY: SNAPSHOT_SIZE,
noInitialRun: !!READ_MEMORY, // skip running main() if we have a snapshot
API, // Pyodide requires we pass this in.
};
}
Expand Down Expand Up @@ -278,16 +284,8 @@ async function instantiateEmscriptenModule(emscriptenSettings) {
async function prepareWasmLinearMemory(Module) {
// Note: if we are restoring from a snapshot, runtime is not initialized yet.
mountLib(Module, SITE_PACKAGES_INFO);
if (MEMORY) {
if (!(MEMORY instanceof Uint8Array)) {
throw new TypeError("Expected MEMORY to be a Uint8Array");
}
// resize linear memory to fit our snapshot. I think `growMemory` only
// exists if `-sALLOW_MEMORY_GROWTH` is passed to the linker but we'll
// probably always do that.
Module.growMemory(MEMORY.byteLength);
// restore memory from snapshot
Module.HEAP8.set(MEMORY);
if (READ_MEMORY) {
READ_MEMORY(Module);
// Don't call adjustSysPath here: it was called in the other branch when we
// were creating the snapshot so the outcome of that is already baked in.
return;
Expand Down Expand Up @@ -440,13 +438,20 @@ function encodeSnapshot(heap, dsoJSON) {
/**
* Decode heap and dsoJSON from the memory snapshot artifact we downloaded
*/
function decodeSnapshot(memorySnapshot) {
const uint32View = new Uint32Array(memorySnapshot);
const snapshotOffset = uint32View[0];
const jsonLength = uint32View[1];
const jsonView = new Uint8Array(memorySnapshot, 8, jsonLength);
DSO_METADATA = JSON.parse(new TextDecoder().decode(jsonView));
MEMORY = new Uint8Array(memorySnapshot, snapshotOffset);
function decodeSnapshot() {
const buf = new Uint32Array(2);
MEMORY_SNAPSHOT_READER.readMemorySnapshot(0, buf);
const snapshotOffset = buf[0];
SNAPSHOT_SIZE = MEMORY_SNAPSHOT_READER.getMemorySnapshotSize() - snapshotOffset;
const jsonLength = buf[1];
const jsonBuf = new Uint8Array(jsonLength);
MEMORY_SNAPSHOT_READER.readMemorySnapshot(8, jsonBuf);
DSO_METADATA = JSON.parse(new TextDecoder().decode(jsonBuf));
READ_MEMORY = function(Module) {
// restore memory from snapshot
MEMORY_SNAPSHOT_READER.readMemorySnapshot(snapshotOffset, Module.HEAP8);
MEMORY_SNAPSHOT_READER.disposeMemorySnapshot();
}
}

/**
Expand Down Expand Up @@ -492,24 +497,22 @@ function simpleRunPython(emscriptenModule, code) {
let TEST_SNAPSHOT = undefined;
(function () {
// Lookup memory snapshot from artifact store.
const memorySnapshot = BUNDLE_MEMORY_SNAPSHOT || ArtifactBundler.getMemorySnapshot();
if (!memorySnapshot) {
if (!MEMORY_SNAPSHOT_READER) {
// snapshots are disabled or there isn't one yet
return;
}
if (memorySnapshot.constructor.name !== "ArrayBuffer") {
throw new TypeError("Expected snapshot to be an ArrayBuffer");
}

// Simple sanity check to ensure this snapshot isn't corrupted.
//
// TODO(later): we need better detection when this is corrupted. Right now the isolate will
// just die.
if (memorySnapshot.byteLength <= 100) {
TEST_SNAPSHOT = memorySnapshot;
const snapshotSize = MEMORY_SNAPSHOT_READER.getMemorySnapshotSize();
if (snapshotSize <= 100) {
TEST_SNAPSHOT = new Uint8Array(snapshotSize);
MEMORY_SNAPSHOT_READER.readMemorySnapshot(0, TEST_SNAPSHOT);
return;
}
decodeSnapshot(memorySnapshot);
decodeSnapshot();
})();

export async function loadPyodide(lockfile, indexURL) {
Expand Down
6 changes: 2 additions & 4 deletions src/pyodide/internal/setupPackages.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ const STDLIB_PACKAGES = Object.values(LOCKFILE.packages)
.filter(({ install_dir }) => install_dir === "stdlib")
.map(({ name }) => canonicalizePackageName(name));


/**
* This stitches together the view of the site packages directory. Each
* requirement corresponds to a folder in the original tar file. For each
Expand Down Expand Up @@ -186,6 +185,5 @@ function addPackageToLoad(lockfile, name, toLoad) {

export { REQUIREMENTS };
export const TRANSITIVE_REQUIREMENTS = getTransitiveRequirements();
export const [SITE_PACKAGES_INFO, SITE_PACKAGES_SO_FILES, USE_LOAD_PACKAGE] = buildSitePackages(
TRANSITIVE_REQUIREMENTS,
);
export const [SITE_PACKAGES_INFO, SITE_PACKAGES_SO_FILES, USE_LOAD_PACKAGE] =
buildSitePackages(TRANSITIVE_REQUIREMENTS);
10 changes: 8 additions & 2 deletions src/pyodide/python-entrypoint-helper.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// This file is a BUILTIN module that provides the actual implementation for the
// python-entrypoint.js USER module.

import { loadPyodide, uploadArtifacts, getMemoryToUpload } from "pyodide-internal:python";
import {
loadPyodide,
uploadArtifacts,
getMemoryToUpload,
} from "pyodide-internal:python";
import { enterJaegerSpan } from "pyodide-internal:jaeger";
import {
REQUIREMENTS,
Expand Down Expand Up @@ -109,7 +113,9 @@ function getMainModule() {
mainModulePromise = (async function () {
const pyodide = await getPyodide();
await setupPackages(pyodide);
return enterJaegerSpan("pyimport_main_module", () => pyimportMainModule(pyodide));
return enterJaegerSpan("pyimport_main_module", () =>
pyimportMainModule(pyodide),
);
})();
return mainModulePromise;
});
Expand Down
38 changes: 25 additions & 13 deletions src/workerd/api/pyodide/pyodide.c++
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
#include "pyodide.h"
#include "kj/array.h"
#include "kj/common.h"
#include "kj/debug.h"

namespace workerd::api::pyodide {

int PackagesTarReader::read(jsg::Lock& js, int offset, kj::Array<kj::byte> buf) {
int tarSize = PYODIDE_PACKAGES_TAR->size();
if (offset >= tarSize || offset < 0) {
static int readToTarget(kj::ArrayPtr<const kj::byte> source, int offset, kj::ArrayPtr<kj::byte> buf) {
int size = source.size();
if (offset >= size || offset < 0) {
return 0;
}
int toCopy = buf.size();
if (tarSize - offset < toCopy) {
toCopy = tarSize - offset;
if (size - offset < toCopy) {
toCopy = size - offset;
}
memcpy(buf.begin(), &((*PYODIDE_PACKAGES_TAR)[0]) + offset, toCopy);
memcpy(buf.begin(), source.begin() + offset, toCopy);
return toCopy;
}

int PackagesTarReader::read(jsg::Lock& js, int offset, kj::Array<kj::byte> buf) {
return readToTarget(PYODIDE_PACKAGES_TAR.get(), offset, buf);
}

kj::Array<jsg::JsRef<jsg::JsString>> PyodideMetadataReader::getNames(jsg::Lock& js) {
auto builder = kj::heapArrayBuilder<jsg::JsRef<jsg::JsString>>(this->names.size());
for (auto i : kj::zeroTo(builder.capacity())) {
Expand Down Expand Up @@ -44,16 +51,21 @@ int PyodideMetadataReader::read(jsg::Lock& js, int index, int offset, kj::Array<
return 0;
}
auto& data = contents[index];
int dataSize = data.size();
if (offset >= dataSize || offset < 0) {
return readToTarget(data, offset, buf);
}

int PyodideMetadataReader::readMemorySnapshot(int offset, kj::Array<kj::byte> buf) {
if (memorySnapshot == kj::none) {
return 0;
}
int toCopy = buf.size();
if (dataSize - offset < toCopy) {
toCopy = dataSize - offset;
return readToTarget(KJ_REQUIRE_NONNULL(memorySnapshot), offset, buf);
}

int ArtifactBundler::readMemorySnapshot(int offset, kj::Array<kj::byte> buf) {
if (existingSnapshot == kj::none) {
return 0;
}
memcpy(buf.begin(), &data[0] + offset, toCopy);
return toCopy;
return readToTarget(KJ_REQUIRE_NONNULL(existingSnapshot), offset, buf);
}

jsg::Ref<PyodideMetadataReader> makePyodideMetadataReader(Worker::Reader conf) {
Expand Down
65 changes: 47 additions & 18 deletions src/workerd/api/pyodide/pyodide.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "kj/array.h"
#include "kj/debug.h"
#include <kj/common.h>
#include <pyodide/generated/pyodide_extra.capnp.h>
#include <pyodide/pyodide.capnp.h>
Expand All @@ -23,6 +25,7 @@ class PackagesTarReader : public jsg::Object {
}
};


// A function to read a segment of the tar file into a buffer
// Set up this way to avoid copying files that aren't accessed.
class PyodideMetadataReader : public jsg::Object {
Expand Down Expand Up @@ -56,10 +59,6 @@ class PyodideMetadataReader : public jsg::Object {
return kj::str(this->mainModule);
}

kj::Maybe<kj::Array<kj::byte>> getMemorySnapshot() {
return kj::mv(memorySnapshot);
}

kj::Array<jsg::JsRef<jsg::JsString>> getNames(jsg::Lock& js);

kj::Array<jsg::JsRef<jsg::JsString>> getRequirements(jsg::Lock& js);
Expand All @@ -68,6 +67,21 @@ class PyodideMetadataReader : public jsg::Object {

int read(jsg::Lock& js, int index, int offset, kj::Array<kj::byte> buf);

bool hasMemorySnapshot() {
return memorySnapshot != kj::none;
}
int getMemorySnapshotSize() {
if (memorySnapshot == kj::none) {
return 0;
}
return KJ_REQUIRE_NONNULL(memorySnapshot).size();
}

void disposeMemorySnapshot() {
memorySnapshot = kj::none;
}
int readMemorySnapshot(int offset, kj::Array<kj::byte> buf);

JSG_RESOURCE_TYPE(PyodideMetadataReader) {
JSG_METHOD(isWorkerd);
JSG_METHOD(isTracing);
Expand All @@ -76,7 +90,10 @@ class PyodideMetadataReader : public jsg::Object {
JSG_METHOD(getNames);
JSG_METHOD(getSizes);
JSG_METHOD(read);
JSG_METHOD(getMemorySnapshot);
JSG_METHOD(hasMemorySnapshot);
JSG_METHOD(getMemorySnapshotSize);
JSG_METHOD(readMemorySnapshot);
JSG_METHOD(disposeMemorySnapshot);
}

void visitForMemoryInfo(jsg::MemoryTracker& tracker) const {
Expand All @@ -101,11 +118,13 @@ class ArtifactBundler : public jsg::Object {

ArtifactBundler(kj::Maybe<kj::Array<kj::byte>> existingSnapshot,
kj::Function<kj::Promise<bool>(kj::Array<kj::byte> snapshot)> uploadMemorySnapshotCb)
: storedSnapshot(kj::none),
:
storedSnapshot(kj::none),
existingSnapshot(kj::mv(existingSnapshot)),
uploadMemorySnapshotCb(kj::mv(uploadMemorySnapshotCb)),
hasUploaded(false),
isValidating(false) {};
isValidating(false)
{};

ArtifactBundler(kj::Maybe<kj::Array<kj::byte>> existingSnapshot)
: storedSnapshot(kj::none),
Expand All @@ -116,7 +135,7 @@ class ArtifactBundler : public jsg::Object {

ArtifactBundler(bool isValidating = false)
: storedSnapshot(kj::none),
existingSnapshot(kj::none),
existingSnapshot(kj::heapArray<kj::byte>(0)),
uploadMemorySnapshotCb(kj::none),
hasUploaded(false),
isValidating(isValidating) {};
Expand Down Expand Up @@ -144,13 +163,6 @@ class ArtifactBundler : public jsg::Object {
storedSnapshot = kj::mv(snapshot);
}

jsg::Optional<kj::Array<kj::byte>> getMemorySnapshot(jsg::Lock& js) {
KJ_IF_SOME(val, existingSnapshot) {
return kj::mv(val);
}
return kj::none;
}

bool isEnabled() {
return uploadMemorySnapshotCb != kj::none;
}
Expand All @@ -159,6 +171,19 @@ class ArtifactBundler : public jsg::Object {
return existingSnapshot != kj::none;
}

int getMemorySnapshotSize() {
if (existingSnapshot == kj::none) {
return 0;
}
return KJ_REQUIRE_NONNULL(existingSnapshot).size();
}

int readMemorySnapshot(int offset, kj::Array<kj::byte> buf);
void disposeMemorySnapshot() {
existingSnapshot = kj::none;
}


// Determines whether this ArtifactBundler was created inside the validator.
bool isEwValidating() {
return isValidating;
Expand All @@ -169,14 +194,18 @@ class ArtifactBundler : public jsg::Object {
}

void visitForMemoryInfo(jsg::MemoryTracker& tracker) const {
KJ_IF_SOME(snapshot, existingSnapshot) {
tracker.trackFieldWithSize("snapshot", snapshot.size());
if (existingSnapshot == kj::none) {
return;
}
tracker.trackFieldWithSize("snapshot", KJ_REQUIRE_NONNULL(existingSnapshot).size());
}

JSG_RESOURCE_TYPE(ArtifactBundler) {
JSG_METHOD(uploadMemorySnapshot);
JSG_METHOD(getMemorySnapshot);
JSG_METHOD(hasMemorySnapshot);
JSG_METHOD(getMemorySnapshotSize);
JSG_METHOD(readMemorySnapshot);
JSG_METHOD(disposeMemorySnapshot);
JSG_METHOD(isEnabled);
JSG_METHOD(isEwValidating);
JSG_METHOD(storeMemorySnapshot);
Expand Down

0 comments on commit 84687f2

Please sign in to comment.