diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts
index 476a0fdee..99f5ccfd6 100644
--- a/packages/backend/src/studio.ts
+++ b/packages/backend/src/studio.ts
@@ -267,7 +267,7 @@ export class Studio {
       ),
     );
     this.#extensionContext.subscriptions.push(
-      this.#inferenceProviderRegistry.register(new WhisperCpp(this.#taskRegistry)),
+      this.#inferenceProviderRegistry.register(new WhisperCpp(this.#taskRegistry, this.#podmanConnection)),
     );

     /**
diff --git a/packages/backend/src/utils/inferenceUtils.spec.ts b/packages/backend/src/utils/inferenceUtils.spec.ts
index df7cb601b..c9f069dce 100644
--- a/packages/backend/src/utils/inferenceUtils.spec.ts
+++ b/packages/backend/src/utils/inferenceUtils.spec.ts
@@ -21,6 +21,7 @@ import { getFreeRandomPort } from './ports';
 import type { ModelInfo } from '@shared/src/models/IModelInfo';
 import type { InferenceServer, InferenceServerStatus } from '@shared/src/models/IInference';
 import { InferenceType } from '@shared/src/models/IInference';
+import type { ContainerProviderConnectionInfo } from '@shared/src/models/IContainerConnectionInfo';

 vi.mock('./ports', () => ({
   getFreeRandomPort: vi.fn(),
@@ -46,14 +47,17 @@ describe('withDefaultConfiguration', () => {
     expect(result.port).toBe(8888);
     expect(result.image).toBe(undefined);
     expect(result.labels).toStrictEqual({});
-    expect(result.providerId).toBe(undefined);
+    expect(result.connection).toBe(undefined);
   });

   test('expect no default values', async () => {
+    const connectionMock = {
+      name: 'Dummy Connection',
+    } as unknown as ContainerProviderConnectionInfo;
     const result = await withDefaultConfiguration({
       modelsInfo: [{ id: 'dummyId' } as unknown as ModelInfo],
       port: 9999,
-      providerId: 'dummyProviderId',
+      connection: connectionMock,
       image: 'random-image',
       labels: { hello: 'world' },
     });
@@ -63,7 +67,7 @@ describe('withDefaultConfiguration', () => {
     expect(result.port).toBe(9999);
     expect(result.image).toBe('random-image');
     expect(result.labels).toStrictEqual({ hello: 'world' });
-    expect(result.providerId).toBe('dummyProviderId');
+    expect(result.connection).toBe(connectionMock);
   });
 });
diff --git a/packages/backend/src/utils/inferenceUtils.ts b/packages/backend/src/utils/inferenceUtils.ts
index bd7a0ee50..d9e39f36d 100644
--- a/packages/backend/src/utils/inferenceUtils.ts
+++ b/packages/backend/src/utils/inferenceUtils.ts
@@ -20,8 +20,6 @@ import {
   type ContainerProviderConnection,
   type ImageInfo,
   type ListImagesOptions,
-  provider,
-  type ProviderContainerConnection,
   type PullEvent,
 } from '@podman-desktop/api';
 import type { CreationInferenceServerOptions, InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
@@ -31,33 +29,6 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo';

 export const LABEL_INFERENCE_SERVER: string = 'ai-lab-inference-server';

-/**
- * Return container connection provider
- */
-export function getProviderContainerConnection(providerId?: string): ProviderContainerConnection {
-  // Get started providers
-  const providers = provider
-    .getContainerConnections()
-    .filter(connection => connection.connection.status() === 'started');
-
-  if (providers.length === 0) throw new Error('no engine started could be find.');
-
-  let output: ProviderContainerConnection | undefined = undefined;
-
-  // If we expect a specific engine
-  if (providerId !== undefined) {
-    output = providers.find(engine => engine.providerId === providerId);
-  } else {
-    // Have a preference for a podman engine
-    output = providers.find(engine => engine.connection.type === 'podman');
-    if (output === undefined) {
-      output = providers[0];
-    }
-  }
-  if (output === undefined) throw new Error('cannot find any started container provider.');
-  return output;
-}
-
 /**
  * Given an image name, it will return the ImageInspectInfo corresponding. Will raise an error if not found.
  * @param connection
@@ -107,7 +78,7 @@ export async function withDefaultConfiguration(
     image: options.image,
     labels: options.labels || {},
     modelsInfo: options.modelsInfo,
-    providerId: options.providerId,
+    connection: options.connection,
     inferenceProvider: options.inferenceProvider,
     gpuLayers: options.gpuLayers !== undefined ? options.gpuLayers : -1,
   };
diff --git a/packages/backend/src/workers/provider/InferenceProvider.spec.ts b/packages/backend/src/workers/provider/InferenceProvider.spec.ts
index d30dc39d5..884700c68 100644
--- a/packages/backend/src/workers/provider/InferenceProvider.spec.ts
+++ b/packages/backend/src/workers/provider/InferenceProvider.spec.ts
@@ -20,20 +20,14 @@ import { beforeEach, describe, expect, test, vi } from 'vitest';
 import type { TaskRegistry } from '../../registries/TaskRegistry';
 import { type BetterContainerCreateResult, InferenceProvider } from './InferenceProvider';
 import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
-import type {
-  ContainerCreateOptions,
-  ContainerProviderConnection,
-  ImageInfo,
-  ProviderContainerConnection,
-} from '@podman-desktop/api';
+import type { ContainerCreateOptions, ContainerProviderConnection, ImageInfo } from '@podman-desktop/api';
 import { containerEngine } from '@podman-desktop/api';
-import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils';
+import { getImageInfo } from '../../utils/inferenceUtils';
 import type { TaskState } from '@shared/src/models/ITask';
 import type { InferenceServer } from '@shared/src/models/IInference';
 import { InferenceType } from '@shared/src/models/IInference';

 vi.mock('../../utils/inferenceUtils', () => ({
-  getProviderContainerConnection: vi.fn(),
   getImageInfo: vi.fn(),
   LABEL_INFERENCE_SERVER: 'ai-lab-inference-server',
 }));
@@ -44,14 +38,6 @@ vi.mock('@podman-desktop/api', () => ({
   containerEngine: {
   },
 }));

-const DummyProviderContainerConnection: ProviderContainerConnection = {
-  providerId: 'dummy-provider-id',
-  connection: {
-    name: 'dummy-provider-connection',
-    type: 'podman',
-  } as unknown as ContainerProviderConnection,
-};
-
 const DummyImageInfo: ImageInfo = {
   Id: 'dummy-image-id',
   engineId: 'dummy-engine-id',
@@ -62,6 +48,11 @@ const taskRegistry: TaskRegistry = {
   updateTask: vi.fn(),
 } as unknown as TaskRegistry;

+const connectionMock: ContainerProviderConnection = {
+  name: 'Dummy Connection',
+  type: 'podman',
+} as unknown as ContainerProviderConnection;
+
 class TestInferenceProvider extends InferenceProvider {
   constructor() {
     super(taskRegistry, InferenceType.NONE, 'test-inference-provider');
@@ -71,8 +62,8 @@ class TestInferenceProvider extends InferenceProvider {
     throw new Error('not implemented');
   }

-  publicPullImage(providerId: string | undefined, image: string, labels: { [id: string]: string }) {
-    return super.pullImage(providerId, image, labels);
+  publicPullImage(connection: ContainerProviderConnection, image: string, labels: { [id: string]: string }) {
+    return super.pullImage(connection, image, labels);
   }

   async publicCreateContainer(
@@ -96,7 +87,6 @@ class TestInferenceProvider extends InferenceProvider {

 beforeEach(() => {
   vi.resetAllMocks();
-  vi.mocked(getProviderContainerConnection).mockReturnValue(DummyProviderContainerConnection);
   vi.mocked(getImageInfo).mockResolvedValue(DummyImageInfo);
   vi.mocked(taskRegistry.createTask).mockImplementation(
     (name: string, state: TaskState, labels: { [id: string]: string } = {}) => ({
@@ -115,7 +105,7 @@ beforeEach(() => {
 describe('pullImage', () => {
   test('should create a task and mark as success on completion', async () => {
     const provider = new TestInferenceProvider();
-    await provider.publicPullImage('dummy-provider-id', 'dummy-image', {
+    await provider.publicPullImage(connectionMock, 'dummy-image', {
       key: 'value',
     });

@@ -138,7 +128,7 @@ describe('pullImage', () => {
     vi.mocked(getImageInfo).mockRejectedValue(new Error('dummy test error'));

     await expect(
-      provider.publicPullImage('dummy-provider-id', 'dummy-image', {
+      provider.publicPullImage(connectionMock, 'dummy-image', {
         key: 'value',
       }),
     ).rejects.toThrowError('dummy test error');
diff --git a/packages/backend/src/workers/provider/InferenceProvider.ts b/packages/backend/src/workers/provider/InferenceProvider.ts
index 3c9824865..96cba4947 100644
--- a/packages/backend/src/workers/provider/InferenceProvider.ts
+++ b/packages/backend/src/workers/provider/InferenceProvider.ts
@@ -18,6 +18,7 @@
 import type {
   ContainerCreateOptions,
   ContainerCreateResult,
+  ContainerProviderConnection,
   Disposable,
   ImageInfo,
   PullEvent,
@@ -26,7 +27,7 @@ import { containerEngine } from '@podman-desktop/api';
 import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
 import type { IWorker } from '../IWorker';
 import type { TaskRegistry } from '../../registries/TaskRegistry';
-import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils';
+import { getImageInfo } from '../../utils/inferenceUtils';
 import type { InferenceServer, InferenceType } from '@shared/src/models/IInference';

 export type BetterContainerCreateResult = ContainerCreateResult & { engineId: string };
@@ -77,24 +78,21 @@ export abstract class InferenceProvider implements IWorker {
     // Creating a task to follow pulling progress
     const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels);

-    // Get the provider
-    const provider = getProviderContainerConnection(providerId);
-
     // get the default image info for this provider
-    return getImageInfo(provider.connection, image, (_event: PullEvent) => {})
+    return getImageInfo(connection, image, (_event: PullEvent) => {})
       .catch((err: unknown) => {
         pullingTask.state = 'error';
         pullingTask.progress = undefined;
diff --git a/packages/backend/src/workers/provider/LlamaCppPython.spec.ts b/packages/backend/src/workers/provider/LlamaCppPython.spec.ts
index 57f09896e..b61281556 100644
--- a/packages/backend/src/workers/provider/LlamaCppPython.spec.ts
+++ b/packages/backend/src/workers/provider/LlamaCppPython.spec.ts
@@ -20,8 +20,8 @@ import { beforeEach, describe, expect, test, vi } from 'vitest';
 import type { TaskRegistry } from '../../registries/TaskRegistry';
 import { LlamaCppPython, SECOND } from './LlamaCppPython';
 import type { ModelInfo } from '@shared/src/models/IModelInfo';
-import { getImageInfo, getProviderContainerConnection, LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
-import type { ContainerProviderConnection, ImageInfo, ProviderContainerConnection } from '@podman-desktop/api';
+import { getImageInfo, LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
+import type { ContainerProviderConnection, ImageInfo } from '@podman-desktop/api';
 import { containerEngine } from '@podman-desktop/api';
 import type { GPUManager } from '../../managers/GPUManager';
 import type { PodmanConnection } from '../../managers/podmanConnection';
@@ -31,6 +31,7 @@ import { GPUVendor } from '@shared/src/models/IGPUInfo';
 import type { InferenceServer } from '@shared/src/models/IInference';
 import { InferenceType } from '@shared/src/models/IInference';
 import { llamacpp } from '../../assets/inference-images.json';
+import type { ContainerProviderConnectionInfo } from '@shared/src/models/IContainerConnectionInfo';

 vi.mock('@podman-desktop/api', () => ({
   containerEngine: {
@@ -64,12 +65,14 @@ const DummyModel: ModelInfo = {
   description: 'dummy-desc',
 };

-const DummyProviderContainerConnection: ProviderContainerConnection = {
-  providerId: 'dummy-provider-id',
-  connection: {
-    name: 'dummy-provider-connection',
-    type: 'podman',
-  } as unknown as ContainerProviderConnection,
+const dummyConnection: ContainerProviderConnection = {
+  name: 'dummy-provider-connection',
+  type: 'podman',
+  vmType: VMType.WSL,
+  status: () => 'started',
+  endpoint: {
+    socketPath: 'dummy-socket',
+  },
 };

 const DummyImageInfo: ImageInfo = {
@@ -78,7 +81,8 @@ const DummyImageInfo: ImageInfo = {
 } as unknown as ImageInfo;

 const podmanConnection: PodmanConnection = {
-  getVMType: vi.fn(),
+  findRunningContainerProviderConnection: vi.fn(),
+  getContainerProviderConnection: vi.fn(),
 } as unknown as PodmanConnection;

 const configurationRegistry: ConfigurationRegistry = {
@@ -92,8 +96,8 @@ beforeEach(() => {
     experimentalGPU: false,
     modelsPath: 'model-path',
   });
-  vi.mocked(podmanConnection.getVMType).mockResolvedValue(VMType.WSL);
-  vi.mocked(getProviderContainerConnection).mockReturnValue(DummyProviderContainerConnection);
+  vi.mocked(podmanConnection.findRunningContainerProviderConnection).mockReturnValue(dummyConnection);
+  vi.mocked(podmanConnection.getContainerProviderConnection).mockReturnValue(dummyConnection);
   vi.mocked(getImageInfo).mockResolvedValue(DummyImageInfo);
   vi.mocked(taskRegistry.createTask).mockReturnValue({ id: 'dummy-task-id', name: '', labels: {}, state: 'loading' });
   vi.mocked(containerEngine.createContainer).mockResolvedValue({
@@ -116,15 +120,11 @@ describe('perform', () => {
       image: undefined,
       labels: {},
       modelsInfo: [DummyModel],
-      providerId: undefined,
+      connection: undefined,
     });

-    expect(getProviderContainerConnection).toHaveBeenCalledWith(undefined);
-    expect(getImageInfo).toHaveBeenCalledWith(
-      DummyProviderContainerConnection.connection,
-      llamacpp.default,
-      expect.anything(),
-    );
+    expect(podmanConnection.findRunningContainerProviderConnection).toHaveBeenCalled();
+    expect(getImageInfo).toHaveBeenCalledWith(dummyConnection, llamacpp.default, expect.anything());
   });

   test('config without models should throw an error', async () => {
@@ -136,7 +136,7 @@ describe('perform', () => {
         image: undefined,
         labels: {},
         modelsInfo: [],
-        providerId: undefined,
+        connection: undefined,
       }),
     ).rejects.toThrowError('Need at least one model info to start an inference server.');
   });
@@ -154,7 +154,7 @@ describe('perform', () => {
             id: 'invalid',
           } as unknown as ModelInfo,
         ],
-        providerId: undefined,
+        connection: undefined,
       }),
     ).rejects.toThrowError('The model info file provided is undefined');
   });
@@ -167,7 +167,7 @@ describe('perform', () => {
       image: undefined,
       labels: {},
       modelsInfo: [DummyModel],
-      providerId: undefined,
+      connection: undefined,
     });

     expect(server).toStrictEqual({
@@ -247,7 +247,7 @@ describe('perform', () => {
           },
         },
       ],
-      providerId: undefined,
+      connection: undefined,
     });

     expect(containerEngine.createContainer).toHaveBeenCalledWith(DummyImageInfo.engineId, {
@@ -287,7 +287,7 @@ describe('perform', () => {
       image: undefined,
       labels: {},
       modelsInfo: [DummyModel],
-      providerId: undefined,
+      connection: undefined,
     });

     expect(gpuManager.collectGPUs).toHaveBeenCalled();
@@ -316,7 +316,7 @@ describe('perform', () => {
       image: undefined,
       labels: {},
       modelsInfo: [DummyModel],
-      providerId: undefined,
+      connection: undefined,
     });

     expect(gpuManager.collectGPUs).toHaveBeenCalled();
   });
@@ -324,6 +324,10 @@ describe('perform', () => {
   });

   test('LIBKRUN vmtype should uses llamacpp.vulkan image', async () => {
+    vi.mocked(podmanConnection.findRunningContainerProviderConnection).mockReturnValue({
+      ...dummyConnection,
+      vmType: VMType.LIBKRUN,
+    });
     vi.mocked(configurationRegistry.getExtensionConfiguration).mockReturnValue({
       experimentalGPU: true,
       modelsPath: '',
@@ -337,18 +341,40 @@ describe('perform', () => {
       },
     ]);

-    vi.mocked(podmanConnection.getVMType).mockResolvedValue(VMType.LIBKRUN);
     const provider = new LlamaCppPython(taskRegistry, podmanConnection, gpuManager, configurationRegistry);
     const server = await provider.perform({
       port: 8000,
       image: undefined,
       labels: {},
       modelsInfo: [DummyModel],
-      providerId: undefined,
+      connection: undefined,
     });

     expect(getImageInfo).toHaveBeenCalledWith(expect.anything(), llamacpp.vulkan, expect.any(Function));
     expect(gpuManager.collectGPUs).toHaveBeenCalled();
     expect('gpu' in server.labels).toBeTruthy();
   });
+
+  test('provided connection should be used for pulling the image', async () => {
+    const connection: ContainerProviderConnectionInfo = {
+      name: 'Dummy Podman',
+      type: 'podman',
+      vmType: VMType.WSL,
+      status: 'started',
+      providerId: 'fakeProviderId',
+    };
+    const provider = new LlamaCppPython(taskRegistry, podmanConnection, gpuManager, configurationRegistry);
+
+    await provider.perform({
+      port: 8000,
+      image: undefined,
+      labels: {},
+      modelsInfo: [DummyModel],
+      connection: connection,
+    });
+
+    expect(podmanConnection.getContainerProviderConnection).toHaveBeenCalledWith(connection);
+    expect(podmanConnection.findRunningContainerProviderConnection).not.toHaveBeenCalled();
+    expect(getImageInfo).toHaveBeenCalledWith(dummyConnection, llamacpp.default, expect.anything());
+  });
 });
diff --git a/packages/backend/src/workers/provider/LlamaCppPython.ts b/packages/backend/src/workers/provider/LlamaCppPython.ts
index 00d3594f8..183a6c621 100644
--- a/packages/backend/src/workers/provider/LlamaCppPython.ts
+++ b/packages/backend/src/workers/provider/LlamaCppPython.ts
@@ -15,7 +15,13 @@
  *
  * SPDX-License-Identifier: Apache-2.0
  ***********************************************************************/
-import type { ContainerCreateOptions, DeviceRequest, ImageInfo, MountConfig } from '@podman-desktop/api';
+import type {
+  ContainerCreateOptions,
+  ContainerProviderConnection,
+  DeviceRequest,
+  ImageInfo,
+  MountConfig,
+} from '@podman-desktop/api';
 import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
 import { InferenceProvider } from './InferenceProvider';
 import { getModelPropertiesForEnvironment } from '../../utils/modelsUtils';
@@ -195,11 +201,20 @@ export class LlamaCppPython extends InferenceProvider {
       gpu = gpus[0];
     }

-    const vmType = await this.podmanConnection.getVMType();
+    let connection: ContainerProviderConnection | undefined = undefined;
+    if (config.connection) {
+      connection = this.podmanConnection.getContainerProviderConnection(config.connection);
+    } else {
+      connection = this.podmanConnection.findRunningContainerProviderConnection();
+    }
+
+    if (!connection) throw new Error('no running connection could be found');
+
+    const vmType: VMType = (connection.vmType ?? VMType.UNKNOWN) as VMType;

     // pull the image
     const imageInfo: ImageInfo = await this.pullImage(
-      config.providerId,
+      connection,
       config.image ?? this.getLlamaCppInferenceImage(vmType, gpu),
       config.labels,
     );
@@ -208,7 +223,7 @@ export class LlamaCppPython extends InferenceProvider {
     const containerCreateOptions: ContainerCreateOptions = await this.getContainerCreateOptions(
       config,
       imageInfo,
-      vmType,
+      connection.vmType as VMType,
       gpu,
     );
diff --git a/packages/backend/src/workers/provider/WhisperCpp.spec.ts b/packages/backend/src/workers/provider/WhisperCpp.spec.ts
index 6b67a3870..e4c2c666f 100644
--- a/packages/backend/src/workers/provider/WhisperCpp.spec.ts
+++ b/packages/backend/src/workers/provider/WhisperCpp.spec.ts
@@ -21,9 +21,12 @@ import type { TaskRegistry } from '../../registries/TaskRegistry';
 import { WhisperCpp } from './WhisperCpp';
 import type { InferenceServer } from '@shared/src/models/IInference';
 import { InferenceType } from '@shared/src/models/IInference';
-import type { ContainerProviderConnection, ProviderContainerConnection, ImageInfo } from '@podman-desktop/api';
+import type { ContainerProviderConnection, ImageInfo } from '@podman-desktop/api';
 import { containerEngine } from '@podman-desktop/api';
-import { getImageInfo, getProviderContainerConnection } from '../../utils/inferenceUtils';
+import { getImageInfo } from '../../utils/inferenceUtils';
+import type { PodmanConnection } from '../../managers/podmanConnection';
+import type { ContainerProviderConnectionInfo } from '@shared/src/models/IContainerConnectionInfo';
+import { VMType } from '@shared/src/models/IPodman';

 vi.mock('@podman-desktop/api', () => ({
   containerEngine: {
@@ -37,13 +40,10 @@ vi.mock('../../utils/inferenceUtils', () => ({
   LABEL_INFERENCE_SERVER: 'ai-lab-inference-server',
 }));

-const DummyProviderContainerConnection: ProviderContainerConnection = {
-  providerId: 'dummy-provider-id',
-  connection: {
-    name: 'dummy-provider-connection',
-    type: 'podman',
-  } as unknown as ContainerProviderConnection,
-};
+const connectionMock: ContainerProviderConnection = {
+  name: 'dummy-provider-connection',
+  type: 'podman',
+} as unknown as ContainerProviderConnection;

 const DummyImageInfo: ImageInfo = {
   Id: 'dummy-image-id',
@@ -55,8 +55,16 @@ const taskRegistry: TaskRegistry = {
   updateTask: vi.fn(),
 } as unknown as TaskRegistry;

+const podmanConnection: PodmanConnection = {
+  findRunningContainerProviderConnection: vi.fn(),
+  getContainerProviderConnection: vi.fn(),
+} as unknown as PodmanConnection;
+
 beforeEach(() => {
-  vi.mocked(getProviderContainerConnection).mockReturnValue(DummyProviderContainerConnection);
+  vi.resetAllMocks();
+
+  vi.mocked(podmanConnection.findRunningContainerProviderConnection).mockReturnValue(connectionMock);
+  vi.mocked(podmanConnection.getContainerProviderConnection).mockReturnValue(connectionMock);
   vi.mocked(taskRegistry.createTask).mockReturnValue({ id: 'dummy-task-id', name: '', labels: {}, state: 'loading' });
   vi.mocked(getImageInfo).mockResolvedValue(DummyImageInfo);
@@ -67,7 +75,7 @@ beforeEach(() => {
 });

 test('provider requires at least one model', async () => {
-  const provider = new WhisperCpp(taskRegistry);
+  const provider = new WhisperCpp(taskRegistry, podmanConnection);

   await expect(() => {
     return provider.perform({
@@ -79,7 +87,7 @@ test('provider requires at least one model', async () => {
 });

 test('provider requires a downloaded model', async () => {
-  const provider = new WhisperCpp(taskRegistry);
+  const provider = new WhisperCpp(taskRegistry, podmanConnection);

   await expect(() => {
     return provider.perform({
@@ -98,7 +106,7 @@ test('provider requires a downloaded model', async () => {
 });

 test('provider requires a model with backend type Whisper', async () => {
-  const provider = new WhisperCpp(taskRegistry);
+  const provider = new WhisperCpp(taskRegistry, podmanConnection);

   await expect(() => {
     return provider.perform({
@@ -124,7 +132,7 @@ test('provider requires a model with backend type Whisper', async () => {
 });

 test('custom image in inference server config should overwrite default', async () => {
-  const provider = new WhisperCpp(taskRegistry);
+  const provider = new WhisperCpp(taskRegistry, podmanConnection);

   const model = {
     id: 'whisper-cpp',
@@ -147,15 +155,11 @@ test('custom image in inference server config should overwrite default', async (
     modelsInfo: [model],
   });

-  expect(getImageInfo).toHaveBeenCalledWith(
-    DummyProviderContainerConnection.connection,
-    'localhost/whisper-cpp:custom',
-    expect.any(Function),
-  );
+  expect(getImageInfo).toHaveBeenCalledWith(connectionMock, 'localhost/whisper-cpp:custom', expect.any(Function));
 });

 test('provider should propagate labels', async () => {
-  const provider = new WhisperCpp(taskRegistry);
+  const provider = new WhisperCpp(taskRegistry, podmanConnection);

   const model = {
     id: 'whisper-cpp',
@@ -195,3 +199,40 @@ test('provider should propagate labels', async () => {
     type: InferenceType.WHISPER_CPP,
   });
 });
+
+test('provided connection should be used for pulling the image', async () => {
+  const connection: ContainerProviderConnectionInfo = {
+    name: 'Dummy Podman',
+    type: 'podman',
+    vmType: VMType.WSL,
+    status: 'started',
+    providerId: 'fakeProviderId',
+  };
+  const provider = new WhisperCpp(taskRegistry, podmanConnection);
+
+  const model = {
+    id: 'whisper-cpp',
+    name: 'Whisper',
+    properties: {},
+    description: 'whisper desc',
+    file: {
+      file: 'random-file',
+      path: 'path-to-file',
+    },
+    backend: InferenceType.WHISPER_CPP,
+  };
+
+  await provider.perform({
+    connection: connection,
+    port: 8888,
+    labels: {
+      hello: 'world',
+    },
+    image: 'localhost/whisper-cpp:custom',
+    modelsInfo: [model],
+  });
+
+  expect(getImageInfo).toHaveBeenCalledWith(connectionMock, 'localhost/whisper-cpp:custom', expect.any(Function));
+  expect(podmanConnection.getContainerProviderConnection).toHaveBeenCalledWith(connection);
+  expect(podmanConnection.findRunningContainerProviderConnection).not.toHaveBeenCalled();
+});
diff --git a/packages/backend/src/workers/provider/WhisperCpp.ts b/packages/backend/src/workers/provider/WhisperCpp.ts
index 104df6c9f..26f5c7936 100644
--- a/packages/backend/src/workers/provider/WhisperCpp.ts
+++ b/packages/backend/src/workers/provider/WhisperCpp.ts
@@ -21,12 +21,16 @@ import type { InferenceServer } from '@shared/src/models/IInference';
 import { InferenceType } from '@shared/src/models/IInference';
 import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
 import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
-import type { MountConfig } from '@podman-desktop/api';
+import type { ContainerProviderConnection, MountConfig } from '@podman-desktop/api';
 import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils';
 import { whispercpp } from '../../assets/inference-images.json';
+import type { PodmanConnection } from '../../managers/podmanConnection';

 export class WhisperCpp extends InferenceProvider {
-  constructor(taskRegistry: TaskRegistry) {
+  constructor(
+    taskRegistry: TaskRegistry,
+    private podmanConnection: PodmanConnection,
+  ) {
     super(taskRegistry, InferenceType.WHISPER_CPP, 'Whisper-cpp');
   }

@@ -54,7 +58,16 @@ export class WhisperCpp extends InferenceProvider {
       [LABEL_INFERENCE_SERVER]: JSON.stringify(config.modelsInfo.map(model => model.id)),
     };

-    const imageInfo = await this.pullImage(config.providerId, config.image ?? whispercpp.default, labels);
+    let connection: ContainerProviderConnection | undefined = undefined;
+    if (config.connection) {
+      connection = this.podmanConnection.getContainerProviderConnection(config.connection);
+    } else {
+      connection = this.podmanConnection.findRunningContainerProviderConnection();
+    }
+
+    if (!connection) throw new Error('no running connection could be found');
+
+    const imageInfo = await this.pullImage(connection, config.image ?? whispercpp.default, labels);

     const envs: string[] = [`MODEL_PATH=/models/${modelInfo.file.file}`, 'HOST=0.0.0.0', 'PORT=8000'];
     const mounts: MountConfig = [
diff --git a/packages/shared/src/models/InferenceServerConfig.ts b/packages/shared/src/models/InferenceServerConfig.ts
index 79ee34395..f0108a8ba 100644
--- a/packages/shared/src/models/InferenceServerConfig.ts
+++ b/packages/shared/src/models/InferenceServerConfig.ts
@@ -16,6 +16,7 @@
  *
  * SPDX-License-Identifier: Apache-2.0
  ***********************************************************************/
 import type { ModelInfo } from './IModelInfo';
+import type { ContainerProviderConnectionInfo } from './IContainerConnectionInfo';

 export type CreationInferenceServerOptions = Partial<InferenceServerConfig> & { modelsInfo: ModelInfo[] };
@@ -25,9 +26,9 @@ export interface InferenceServerConfig {
    */
   port: number;
   /**
-   * The identifier of the container provider to use
+   * The connection info to use
    */
-  providerId?: string;
+  connection?: ContainerProviderConnectionInfo;
   /**
    * The name of the inference provider to use
    */
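Below is a minimal sketch of how a caller might build creation options with the new connection field, assuming only the types shown in this diff; the helper name buildServerOptions is illustrative and not part of the changeset. When connection is omitted, the providers above fall back to findRunningContainerProviderConnection().

import type { ContainerProviderConnectionInfo } from '@shared/src/models/IContainerConnectionInfo';
import type { CreationInferenceServerOptions } from '@shared/src/models/InferenceServerConfig';
import type { ModelInfo } from '@shared/src/models/IModelInfo';

// Illustrative helper (not part of this diff): assemble the options handed to an
// inference provider. `connection` is optional; when it is left out, the provider
// resolves a running podman connection on its own.
function buildServerOptions(
  model: ModelInfo,
  connection?: ContainerProviderConnectionInfo,
): CreationInferenceServerOptions {
  return {
    modelsInfo: [model],
    // Routes image pulling and container creation to a specific engine.
    connection,
    labels: { hello: 'world' },
  };
}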