Skip to content

Commit

Permalink
feat(InferenceManager): handle podman connections (#1530)
Browse files Browse the repository at this point in the history
* feat(InferenceManager): handle podman connections

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* test: ensuring connection propagation

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

---------

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>
  • Loading branch information
axel7083 committed Aug 13, 2024
1 parent fac1f6a commit d302152
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 117 deletions.
2 changes: 1 addition & 1 deletion packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ export class Studio {
),
);
this.#extensionContext.subscriptions.push(
this.#inferenceProviderRegistry.register(new WhisperCpp(this.#taskRegistry)),
this.#inferenceProviderRegistry.register(new WhisperCpp(this.#taskRegistry, this.#podmanConnection)),
);

/**
Expand Down
10 changes: 7 additions & 3 deletions packages/backend/src/utils/inferenceUtils.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { getFreeRandomPort } from './ports';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import type { InferenceServer, InferenceServerStatus } from '@shared/src/models/IInference';
import { InferenceType } from '@shared/src/models/IInference';
import type { ContainerProviderConnectionInfo } from '@shared/src/models/IContainerConnectionInfo';

vi.mock('./ports', () => ({
getFreeRandomPort: vi.fn(),
Expand All @@ -46,14 +47,17 @@ describe('withDefaultConfiguration', () => {
expect(result.port).toBe(8888);
expect(result.image).toBe(undefined);
expect(result.labels).toStrictEqual({});
expect(result.providerId).toBe(undefined);
expect(result.connection).toBe(undefined);
});

test('expect no default values', async () => {
const connectionMock = {
name: 'Dummy Connection',
} as unknown as ContainerProviderConnectionInfo;
const result = await withDefaultConfiguration({
modelsInfo: [{ id: 'dummyId' } as unknown as ModelInfo],
port: 9999,
providerId: 'dummyProviderId',
connection: connectionMock,
image: 'random-image',
labels: { hello: 'world' },
});
Expand All @@ -63,7 +67,7 @@ describe('withDefaultConfiguration', () => {
expect(result.port).toBe(9999);
expect(result.image).toBe('random-image');
expect(result.labels).toStrictEqual({ hello: 'world' });
expect(result.providerId).toBe('dummyProviderId');
expect(result.connection).toBe(connectionMock);
});
});

Expand Down
31 changes: 1 addition & 30 deletions packages/backend/src/utils/inferenceUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import {
type ContainerProviderConnection,
type ImageInfo,
type ListImagesOptions,
provider,
type ProviderContainerConnection,
type PullEvent,
} from '@podman-desktop/api';
import type { CreationInferenceServerOptions, InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
Expand All @@ -31,33 +29,6 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo';

export const LABEL_INFERENCE_SERVER: string = 'ai-lab-inference-server';

/**
 * Select one started container connection among those reported by
 * `provider.getContainerConnections()`.
 *
 * @param providerId when provided, only a started connection whose `providerId`
 * matches is accepted; when omitted, a started connection of type `podman` is
 * preferred, falling back to the first started connection.
 * @returns the selected provider container connection
 * @throws Error if no connection is started, or if none matches the requested `providerId`
 */
export function getProviderContainerConnection(providerId?: string): ProviderContainerConnection {
  // Only connections currently reporting the 'started' status are candidates
  const started = provider.getContainerConnections().filter(candidate => candidate.connection.status() === 'started');
  if (started.length === 0) {
    throw new Error('no engine started could be find.');
  }

  const select = (): ProviderContainerConnection | undefined => {
    // A specific engine was requested: it is either found or nothing is selected
    if (providerId !== undefined) {
      return started.find(candidate => candidate.providerId === providerId);
    }
    // No specific engine requested: podman first, otherwise the first started one
    const podman = started.find(candidate => candidate.connection.type === 'podman');
    return podman !== undefined ? podman : started[0];
  };

  const selected = select();
  if (selected === undefined) {
    throw new Error('cannot find any started container provider.');
  }
  return selected;
}

/**
* Given an image name, it will return the ImageInspectInfo corresponding. Will raise an error if not found.
* @param connection
Expand Down Expand Up @@ -107,7 +78,7 @@ export async function withDefaultConfiguration(
image: options.image,
labels: options.labels || {},
modelsInfo: options.modelsInfo,
providerId: options.providerId,
connection: options.connection,
inferenceProvider: options.inferenceProvider,
gpuLayers: options.gpuLayers !== undefined ? options.gpuLayers : -1,
};
Expand Down
32 changes: 11 additions & 21 deletions packages/backend/src/workers/provider/InferenceProvider.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,14 @@ import { beforeEach, describe, expect, test, vi } from 'vitest';
import type { TaskRegistry } from '../../registries/TaskRegistry';
import { type BetterContainerCreateResult, InferenceProvider } from './InferenceProvider';
import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
import type {
ContainerCreateOptions,
ContainerProviderConnection,
ImageInfo,
ProviderContainerConnection,
} from '@podman-desktop/api';
import type { ContainerCreateOptions, ContainerProviderConnection, ImageInfo } from '@podman-desktop/api';
import { containerEngine } from '@podman-desktop/api';
import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils';
import { getImageInfo } from '../../utils/inferenceUtils';
import type { TaskState } from '@shared/src/models/ITask';
import type { InferenceServer } from '@shared/src/models/IInference';
import { InferenceType } from '@shared/src/models/IInference';

vi.mock('../../utils/inferenceUtils', () => ({
getProviderContainerConnection: vi.fn(),
getImageInfo: vi.fn(),
LABEL_INFERENCE_SERVER: 'ai-lab-inference-server',
}));
Expand All @@ -44,14 +38,6 @@ vi.mock('@podman-desktop/api', () => ({
},
}));

const DummyProviderContainerConnection: ProviderContainerConnection = {
providerId: 'dummy-provider-id',
connection: {
name: 'dummy-provider-connection',
type: 'podman',
} as unknown as ContainerProviderConnection,
};

const DummyImageInfo: ImageInfo = {
Id: 'dummy-image-id',
engineId: 'dummy-engine-id',
Expand All @@ -62,6 +48,11 @@ const taskRegistry: TaskRegistry = {
updateTask: vi.fn(),
} as unknown as TaskRegistry;

// Partial stand-in for a container provider connection: only `name` and `type`
// are populated, hence the cast through `unknown`. The pullImage tests pass it
// to `publicPullImage` as the connection argument.
const connectionMock: ContainerProviderConnection = {
name: 'Dummy Connection',
type: 'podman',
} as unknown as ContainerProviderConnection;

class TestInferenceProvider extends InferenceProvider {
constructor() {
super(taskRegistry, InferenceType.NONE, 'test-inference-provider');
Expand All @@ -71,8 +62,8 @@ class TestInferenceProvider extends InferenceProvider {
throw new Error('not implemented');
}

publicPullImage(providerId: string | undefined, image: string, labels: { [id: string]: string }) {
return super.pullImage(providerId, image, labels);
publicPullImage(connection: ContainerProviderConnection, image: string, labels: { [id: string]: string }) {
return super.pullImage(connection, image, labels);
}

async publicCreateContainer(
Expand All @@ -96,7 +87,6 @@ class TestInferenceProvider extends InferenceProvider {
beforeEach(() => {
vi.resetAllMocks();

vi.mocked(getProviderContainerConnection).mockReturnValue(DummyProviderContainerConnection);
vi.mocked(getImageInfo).mockResolvedValue(DummyImageInfo);
vi.mocked(taskRegistry.createTask).mockImplementation(
(name: string, state: TaskState, labels: { [id: string]: string } = {}) => ({
Expand All @@ -115,7 +105,7 @@ beforeEach(() => {
describe('pullImage', () => {
test('should create a task and mark as success on completion', async () => {
const provider = new TestInferenceProvider();
await provider.publicPullImage('dummy-provider-id', 'dummy-image', {
await provider.publicPullImage(connectionMock, 'dummy-image', {
key: 'value',
});

Expand All @@ -138,7 +128,7 @@ describe('pullImage', () => {
vi.mocked(getImageInfo).mockRejectedValue(new Error('dummy test error'));

await expect(
provider.publicPullImage('dummy-provider-id', 'dummy-image', {
provider.publicPullImage(connectionMock, 'dummy-image', {
key: 'value',
}),
).rejects.toThrowError('dummy test error');
Expand Down
12 changes: 5 additions & 7 deletions packages/backend/src/workers/provider/InferenceProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import type {
ContainerCreateOptions,
ContainerCreateResult,
ContainerProviderConnection,
Disposable,
ImageInfo,
PullEvent,
Expand All @@ -26,7 +27,7 @@ import { containerEngine } from '@podman-desktop/api';
import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
import type { IWorker } from '../IWorker';
import type { TaskRegistry } from '../../registries/TaskRegistry';
import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils';
import { getImageInfo } from '../../utils/inferenceUtils';
import type { InferenceServer, InferenceType } from '@shared/src/models/IInference';

export type BetterContainerCreateResult = ContainerCreateResult & { engineId: string };
Expand Down Expand Up @@ -77,24 +78,21 @@ export abstract class InferenceProvider implements IWorker<InferenceServerConfig

/**
* This method allows to pull the image, while creating a task for the user to follow progress
* @param providerId
* @param connection
* @param image
* @param labels
* @protected
*/
protected pullImage(
providerId: string | undefined,
connection: ContainerProviderConnection,
image: string,
labels: { [id: string]: string },
): Promise<ImageInfo> {
// Creating a task to follow pulling progress
const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels);

// Get the provider
const provider = getProviderContainerConnection(providerId);

// get the default image info for this provider
return getImageInfo(provider.connection, image, (_event: PullEvent) => {})
return getImageInfo(connection, image, (_event: PullEvent) => {})
.catch((err: unknown) => {
pullingTask.state = 'error';
pullingTask.progress = undefined;
Expand Down
Loading

0 comments on commit d302152

Please sign in to comment.