Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(InferenceManager): handle podman connections #1530

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ export class Studio {
),
);
this.#extensionContext.subscriptions.push(
this.#inferenceProviderRegistry.register(new WhisperCpp(this.#taskRegistry)),
this.#inferenceProviderRegistry.register(new WhisperCpp(this.#taskRegistry, this.#podmanConnection)),
);

/**
Expand Down
10 changes: 7 additions & 3 deletions packages/backend/src/utils/inferenceUtils.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { getFreeRandomPort } from './ports';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import type { InferenceServer, InferenceServerStatus } from '@shared/src/models/IInference';
import { InferenceType } from '@shared/src/models/IInference';
import type { ContainerProviderConnectionInfo } from '@shared/src/models/IContainerConnectionInfo';

vi.mock('./ports', () => ({
getFreeRandomPort: vi.fn(),
Expand All @@ -46,14 +47,17 @@ describe('withDefaultConfiguration', () => {
expect(result.port).toBe(8888);
expect(result.image).toBe(undefined);
expect(result.labels).toStrictEqual({});
expect(result.providerId).toBe(undefined);
expect(result.connection).toBe(undefined);
});

test('expect no default values', async () => {
const connectionMock = {
name: 'Dummy Connection',
} as unknown as ContainerProviderConnectionInfo;
const result = await withDefaultConfiguration({
modelsInfo: [{ id: 'dummyId' } as unknown as ModelInfo],
port: 9999,
providerId: 'dummyProviderId',
connection: connectionMock,
image: 'random-image',
labels: { hello: 'world' },
});
Expand All @@ -63,7 +67,7 @@ describe('withDefaultConfiguration', () => {
expect(result.port).toBe(9999);
expect(result.image).toBe('random-image');
expect(result.labels).toStrictEqual({ hello: 'world' });
expect(result.providerId).toBe('dummyProviderId');
expect(result.connection).toBe(connectionMock);
});
});

Expand Down
31 changes: 1 addition & 30 deletions packages/backend/src/utils/inferenceUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import {
type ContainerProviderConnection,
type ImageInfo,
type ListImagesOptions,
provider,
type ProviderContainerConnection,
type PullEvent,
} from '@podman-desktop/api';
import type { CreationInferenceServerOptions, InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
Expand All @@ -31,33 +29,6 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo';

export const LABEL_INFERENCE_SERVER: string = 'ai-lab-inference-server';

/**
* Return container connection provider
*/
export function getProviderContainerConnection(providerId?: string): ProviderContainerConnection {
// Get started providers
const providers = provider
.getContainerConnections()
.filter(connection => connection.connection.status() === 'started');

if (providers.length === 0) throw new Error('no engine started could be find.');

let output: ProviderContainerConnection | undefined = undefined;

// If we expect a specific engine
if (providerId !== undefined) {
output = providers.find(engine => engine.providerId === providerId);
} else {
// Have a preference for a podman engine
output = providers.find(engine => engine.connection.type === 'podman');
if (output === undefined) {
output = providers[0];
}
}
if (output === undefined) throw new Error('cannot find any started container provider.');
return output;
}

/**
* Given an image name, it will return the ImageInspectInfo corresponding. Will raise an error if not found.
* @param connection
Expand Down Expand Up @@ -107,7 +78,7 @@ export async function withDefaultConfiguration(
image: options.image,
labels: options.labels || {},
modelsInfo: options.modelsInfo,
providerId: options.providerId,
connection: options.connection,
inferenceProvider: options.inferenceProvider,
gpuLayers: options.gpuLayers !== undefined ? options.gpuLayers : -1,
};
Expand Down
32 changes: 11 additions & 21 deletions packages/backend/src/workers/provider/InferenceProvider.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,14 @@ import { beforeEach, describe, expect, test, vi } from 'vitest';
import type { TaskRegistry } from '../../registries/TaskRegistry';
import { type BetterContainerCreateResult, InferenceProvider } from './InferenceProvider';
import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
import type {
ContainerCreateOptions,
ContainerProviderConnection,
ImageInfo,
ProviderContainerConnection,
} from '@podman-desktop/api';
import type { ContainerCreateOptions, ContainerProviderConnection, ImageInfo } from '@podman-desktop/api';
import { containerEngine } from '@podman-desktop/api';
import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils';
import { getImageInfo } from '../../utils/inferenceUtils';
import type { TaskState } from '@shared/src/models/ITask';
import type { InferenceServer } from '@shared/src/models/IInference';
import { InferenceType } from '@shared/src/models/IInference';

vi.mock('../../utils/inferenceUtils', () => ({
getProviderContainerConnection: vi.fn(),
getImageInfo: vi.fn(),
LABEL_INFERENCE_SERVER: 'ai-lab-inference-server',
}));
Expand All @@ -44,14 +38,6 @@ vi.mock('@podman-desktop/api', () => ({
},
}));

const DummyProviderContainerConnection: ProviderContainerConnection = {
providerId: 'dummy-provider-id',
connection: {
name: 'dummy-provider-connection',
type: 'podman',
} as unknown as ContainerProviderConnection,
};

const DummyImageInfo: ImageInfo = {
Id: 'dummy-image-id',
engineId: 'dummy-engine-id',
Expand All @@ -62,6 +48,11 @@ const taskRegistry: TaskRegistry = {
updateTask: vi.fn(),
} as unknown as TaskRegistry;

// Minimal container-connection stub shared by the pullImage tests; the double
// assertion is deliberate — only the fields read by the code under test exist.
const connectionMock = {
  name: 'Dummy Connection',
  type: 'podman',
} as unknown as ContainerProviderConnection;

class TestInferenceProvider extends InferenceProvider {
constructor() {
super(taskRegistry, InferenceType.NONE, 'test-inference-provider');
Expand All @@ -71,8 +62,8 @@ class TestInferenceProvider extends InferenceProvider {
throw new Error('not implemented');
}

publicPullImage(providerId: string | undefined, image: string, labels: { [id: string]: string }) {
return super.pullImage(providerId, image, labels);
publicPullImage(connection: ContainerProviderConnection, image: string, labels: { [id: string]: string }) {
return super.pullImage(connection, image, labels);
}

async publicCreateContainer(
Expand All @@ -96,7 +87,6 @@ class TestInferenceProvider extends InferenceProvider {
beforeEach(() => {
vi.resetAllMocks();

vi.mocked(getProviderContainerConnection).mockReturnValue(DummyProviderContainerConnection);
vi.mocked(getImageInfo).mockResolvedValue(DummyImageInfo);
vi.mocked(taskRegistry.createTask).mockImplementation(
(name: string, state: TaskState, labels: { [id: string]: string } = {}) => ({
Expand All @@ -115,7 +105,7 @@ beforeEach(() => {
describe('pullImage', () => {
test('should create a task and mark as success on completion', async () => {
const provider = new TestInferenceProvider();
await provider.publicPullImage('dummy-provider-id', 'dummy-image', {
await provider.publicPullImage(connectionMock, 'dummy-image', {
key: 'value',
});

Expand All @@ -138,7 +128,7 @@ describe('pullImage', () => {
vi.mocked(getImageInfo).mockRejectedValue(new Error('dummy test error'));

await expect(
provider.publicPullImage('dummy-provider-id', 'dummy-image', {
provider.publicPullImage(connectionMock, 'dummy-image', {
key: 'value',
}),
).rejects.toThrowError('dummy test error');
Expand Down
12 changes: 5 additions & 7 deletions packages/backend/src/workers/provider/InferenceProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import type {
ContainerCreateOptions,
ContainerCreateResult,
ContainerProviderConnection,
Disposable,
ImageInfo,
PullEvent,
Expand All @@ -26,7 +27,7 @@ import { containerEngine } from '@podman-desktop/api';
import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
import type { IWorker } from '../IWorker';
import type { TaskRegistry } from '../../registries/TaskRegistry';
import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils';
import { getImageInfo } from '../../utils/inferenceUtils';
import type { InferenceServer, InferenceType } from '@shared/src/models/IInference';

export type BetterContainerCreateResult = ContainerCreateResult & { engineId: string };
Expand Down Expand Up @@ -77,24 +78,21 @@ export abstract class InferenceProvider implements IWorker<InferenceServerConfig

/**
* This method allows to pull the image, while creating a task for the user to follow progress
* @param providerId
* @param connection
* @param image
* @param labels
* @protected
*/
protected pullImage(
providerId: string | undefined,
connection: ContainerProviderConnection,
image: string,
labels: { [id: string]: string },
): Promise<ImageInfo> {
// Creating a task to follow pulling progress
const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels);

// Get the provider
const provider = getProviderContainerConnection(providerId);

// get the default image info for this provider
return getImageInfo(provider.connection, image, (_event: PullEvent) => {})
return getImageInfo(connection, image, (_event: PullEvent) => {})
.catch((err: unknown) => {
pullingTask.state = 'error';
pullingTask.progress = undefined;
Expand Down
Loading