Skip to content

Commit

Permalink
Fix a couple bugs where not all params are passed through to the requ…
Browse files Browse the repository at this point in the history
…est (#103)
  • Loading branch information
hsubox76 committed Apr 16, 2024
1 parent 9fe865d commit 6ef8cee
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 12 deletions.
5 changes: 5 additions & 0 deletions .changeset/many-chefs-appear.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@google/generative-ai": patch
---

Fixed bugs where `RequestOptions`, `generationConfig`, and `safetySettings` were not passed from the model down to some methods.
95 changes: 85 additions & 10 deletions packages/main/src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import { expect, use } from "chai";
import { GenerativeModel } from "./generative-model";
import * as sinonChai from "sinon-chai";
import { FunctionCallingMode } from "../../types";
import {
FunctionCallingMode,
HarmBlockThreshold,
HarmCategory,
} from "../../types";
import { getMockResponse } from "../../test-utils/mock-response";
import { match, restore, stub } from "sinon";
import * as request from "../requests/request";
Expand All @@ -42,12 +46,29 @@ describe("GenerativeModel", () => {
expect(genModel.model).to.equal("tunedModels/my-model");
});
it("passes params through to generateContent", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
const genModel = new GenerativeModel(
"apiKey",
{
model: "my-model",
generationConfig: { temperature: 0 },
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
},
],
tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
toolConfig: {
functionCallingConfig: { mode: FunctionCallingMode.NONE },
},
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
},
{
apiVersion: "v6",
},
);
expect(genModel.generationConfig?.temperature).to.equal(0);
expect(genModel.safetySettings?.length).to.equal(1);
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE,
Expand All @@ -69,16 +90,27 @@ describe("GenerativeModel", () => {
return (
value.includes("myfunc") &&
value.includes(FunctionCallingMode.NONE) &&
value.includes("be friendly")
value.includes("be friendly") &&
value.includes("temperature") &&
value.includes(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
);
}),
{},
match((value) => {
return value.apiVersion === "v6";
}),
);
restore();
});
it("generateContent overrides model values", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
generationConfig: { temperature: 0 },
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
},
],
tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
Expand All @@ -95,6 +127,17 @@ describe("GenerativeModel", () => {
mockResponse as Response,
);
await genModel.generateContent({
generationConfig: { topK: 1 },
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
},
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: HarmBlockThreshold.BLOCK_NONE,
},
],
contents: [{ role: "user", parts: [{ text: "hello" }] }],
tools: [{ functionDeclarations: [{ name: "otherfunc" }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } },
Expand All @@ -109,13 +152,45 @@ describe("GenerativeModel", () => {
return (
value.includes("otherfunc") &&
value.includes(FunctionCallingMode.AUTO) &&
value.includes("be formal")
value.includes("be formal") &&
value.includes("topK") &&
value.includes(HarmCategory.HARM_CATEGORY_HARASSMENT)
);
}),
{},
);
restore();
});
it("passes requestOptions through to countTokens", async () => {
const genModel = new GenerativeModel(
"apiKey",
{
model: "my-model",
},
{
apiVersion: "v2000",
},
);
const mockResponse = getMockResponse(
"unary-success-basic-reply-short.json",
);
const makeRequestStub = stub(request, "makeRequest").resolves(
mockResponse as Response,
);
await genModel.countTokens("hello");
console.log(makeRequestStub.args[0]);
expect(makeRequestStub).to.be.calledWith(
"models/my-model",
request.Task.COUNT_TOKENS,
match.any,
false,
match.any,
match((value) => {
return value.apiVersion === "v2000";
}),
);
restore();
});
it("passes params through to chat.sendMessage", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
Expand Down
16 changes: 14 additions & 2 deletions packages/main/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ export class GenerativeModel {
this.apiKey,
this.model,
{
generationConfig: this.generationConfig,
safetySettings: this.safetySettings,
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
Expand All @@ -153,7 +155,12 @@ export class GenerativeModel {
request: CountTokensRequest | string | Array<string | Part>,
): Promise<CountTokensResponse> {
const formattedParams = formatGenerateContentInput(request);
return countTokens(this.apiKey, this.model, formattedParams);
return countTokens(
this.apiKey,
this.model,
formattedParams,
this.requestOptions,
);
}

/**
Expand All @@ -163,7 +170,12 @@ export class GenerativeModel {
request: EmbedContentRequest | string | Array<string | Part>,
): Promise<EmbedContentResponse> {
const formattedParams = formatEmbedContentInput(request);
return embedContent(this.apiKey, this.model, formattedParams);
return embedContent(
this.apiKey,
this.model,
formattedParams,
this.requestOptions,
);
}

/**
Expand Down

0 comments on commit 6ef8cee

Please sign in to comment.