diff --git a/.changeset/many-chefs-appear.md b/.changeset/many-chefs-appear.md new file mode 100644 index 00000000..ddc619b4 --- /dev/null +++ b/.changeset/many-chefs-appear.md @@ -0,0 +1,5 @@ +--- +"@google/generative-ai": patch +--- + +Fixed bugs where `RequestOptions`, `generationConfig`, and `safetySettings` were not passed from the model down to some methods. diff --git a/packages/main/src/models/generative-model.test.ts b/packages/main/src/models/generative-model.test.ts index 254d3712..42603a66 100644 --- a/packages/main/src/models/generative-model.test.ts +++ b/packages/main/src/models/generative-model.test.ts @@ -17,7 +17,11 @@ import { expect, use } from "chai"; import { GenerativeModel } from "./generative-model"; import * as sinonChai from "sinon-chai"; -import { FunctionCallingMode } from "../../types"; +import { + FunctionCallingMode, + HarmBlockThreshold, + HarmCategory, +} from "../../types"; import { getMockResponse } from "../../test-utils/mock-response"; import { match, restore, stub } from "sinon"; import * as request from "../requests/request"; @@ -42,12 +46,29 @@ describe("GenerativeModel", () => { expect(genModel.model).to.equal("tunedModels/my-model"); }); it("passes params through to generateContent", async () => { - const genModel = new GenerativeModel("apiKey", { - model: "my-model", - tools: [{ functionDeclarations: [{ name: "myfunc" }] }], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: "system", parts: [{ text: "be friendly" }] }, - }); + const genModel = new GenerativeModel( + "apiKey", + { + model: "my-model", + generationConfig: { temperature: 0 }, + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + }, + ], + tools: [{ functionDeclarations: [{ name: "myfunc" }] }], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.NONE }, + }, + systemInstruction: { role: "system", parts: [{ text: "be friendly" }] }, + }, + { + apiVersion: "v6", + }, + ); + expect(genModel.generationConfig?.temperature).to.equal(0); + expect(genModel.safetySettings?.length).to.equal(1); expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal( FunctionCallingMode.NONE, @@ -69,16 +90,27 @@ describe("GenerativeModel", () => { return ( value.includes("myfunc") && value.includes(FunctionCallingMode.NONE) && - value.includes("be friendly") + value.includes("be friendly") && + value.includes("temperature") && + value.includes(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) ); }), - {}, + match((value) => { + return value.apiVersion === "v6"; + }), ); restore(); }); it("generateContent overrides model values", async () => { const genModel = new GenerativeModel("apiKey", { model: "my-model", + generationConfig: { temperature: 0 }, + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + }, + ], tools: [{ functionDeclarations: [{ name: "myfunc" }] }], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, systemInstruction: { role: "system", parts: [{ text: "be friendly" }] }, @@ -95,6 +127,17 @@ describe("GenerativeModel", () => { mockResponse as Response, ); await genModel.generateContent({ + generationConfig: { topK: 1 }, + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + }, + { + category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold: HarmBlockThreshold.BLOCK_NONE, + }, + ], contents: [{ role: "user", parts: [{ text: "hello" }] }], tools: [{ functionDeclarations: [{ name: "otherfunc" }] }], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } }, @@ -109,13 +152,45 @@ describe("GenerativeModel", () => { return ( value.includes("otherfunc") && value.includes(FunctionCallingMode.AUTO) && - value.includes("be formal") + value.includes("be formal") && + value.includes("topK") && + value.includes(HarmCategory.HARM_CATEGORY_HARASSMENT) ); }), {}, ); restore(); }); + it("passes requestOptions through to countTokens", async () => { + const genModel = new GenerativeModel( + "apiKey", + { + model: "my-model", + }, + { + apiVersion: "v2000", + }, + ); + const mockResponse = getMockResponse( + "unary-success-basic-reply-short.json", + ); + const makeRequestStub = stub(request, "makeRequest").resolves( + mockResponse as Response, + ); + await genModel.countTokens("hello"); + console.log(makeRequestStub.args[0]); + expect(makeRequestStub).to.be.calledWith( + "models/my-model", + request.Task.COUNT_TOKENS, + match.any, + false, + match.any, + match((value) => { + return value.apiVersion === "v2000"; + }), + ); + restore(); + }); it("passes params through to chat.sendMessage", async () => { const genModel = new GenerativeModel("apiKey", { model: "my-model", diff --git a/packages/main/src/models/generative-model.ts b/packages/main/src/models/generative-model.ts index d2e1b2e3..e3d51fcb 100644 --- a/packages/main/src/models/generative-model.ts +++ b/packages/main/src/models/generative-model.ts @@ -137,6 +137,8 @@ export class GenerativeModel { this.apiKey, this.model, { + generationConfig: this.generationConfig, + safetySettings: this.safetySettings, tools: this.tools, toolConfig: this.toolConfig, systemInstruction: this.systemInstruction, @@ -153,7 +155,12 @@ export class GenerativeModel { request: CountTokensRequest | string | Array, ): Promise { const formattedParams = formatGenerateContentInput(request); - return countTokens(this.apiKey, this.model, formattedParams); + return countTokens( + this.apiKey, + this.model, + formattedParams, + this.requestOptions, + ); } /** @@ -163,7 +170,12 @@ export class GenerativeModel { request: EmbedContentRequest | string | Array, ): Promise { const formattedParams = formatEmbedContentInput(request); - return embedContent(this.apiKey, this.model, formattedParams); + return embedContent( + this.apiKey, + this.model, + formattedParams, + this.requestOptions, + ); } /**