Skip to content

Commit

Permalink
Allow user to provide systemInstruction as string or Part (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsubox76 committed Apr 29, 2024
1 parent 111e970 commit ca62400
Show file tree
Hide file tree
Showing 12 changed files with 273 additions and 13 deletions.
5 changes: 5 additions & 0 deletions .changeset/perfect-hotels-protect.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@google/generative-ai": minor
---

Allow text-only systemInstruction as well as Part and Content.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export interface GenerateContentRequest extends BaseParams
| Property | Modifiers | Type | Description |
| --- | --- | --- | --- |
| [contents](./generative-ai.generatecontentrequest.contents.md) | | [Content](./generative-ai.content.md)<!-- -->\[\] | |
| [systemInstruction?](./generative-ai.generatecontentrequest.systeminstruction.md) | | [Content](./generative-ai.content.md) | _(Optional)_ |
| [systemInstruction?](./generative-ai.generatecontentrequest.systeminstruction.md) | | string \| [Part](./generative-ai.part.md) \| [Content](./generative-ai.content.md) | _(Optional)_ |
| [toolConfig?](./generative-ai.generatecontentrequest.toolconfig.md) | | [ToolConfig](./generative-ai.toolconfig.md) | _(Optional)_ |
| [tools?](./generative-ai.generatecontentrequest.tools.md) | | [Tool](./generative-ai.tool.md)<!-- -->\[\] | _(Optional)_ |
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
**Signature:**

```typescript
systemInstruction?: Content;
systemInstruction?: string | Part | Content;
```
2 changes: 1 addition & 1 deletion docs/reference/main/generative-ai.modelparams.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export interface ModelParams extends BaseParams
| Property | Modifiers | Type | Description |
| --- | --- | --- | --- |
| [model](./generative-ai.modelparams.model.md) | | string | |
| [systemInstruction?](./generative-ai.modelparams.systeminstruction.md) | | [Content](./generative-ai.content.md) | _(Optional)_ |
| [systemInstruction?](./generative-ai.modelparams.systeminstruction.md) | | string \| [Part](./generative-ai.part.md) \| [Content](./generative-ai.content.md) | _(Optional)_ |
| [toolConfig?](./generative-ai.modelparams.toolconfig.md) | | [ToolConfig](./generative-ai.toolconfig.md) | _(Optional)_ |
| [tools?](./generative-ai.modelparams.tools.md) | | [Tool](./generative-ai.tool.md)<!-- -->\[\] | _(Optional)_ |
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
**Signature:**

```typescript
systemInstruction?: Content;
systemInstruction?: string | Part | Content;
```
2 changes: 1 addition & 1 deletion docs/reference/main/generative-ai.startchatparams.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export interface StartChatParams extends BaseParams
| Property | Modifiers | Type | Description |
| --- | --- | --- | --- |
| [history?](./generative-ai.startchatparams.history.md) | | [Content](./generative-ai.content.md)<!-- -->\[\] | _(Optional)_ |
| [systemInstruction?](./generative-ai.startchatparams.systeminstruction.md) | | [Content](./generative-ai.content.md) | _(Optional)_ |
| [systemInstruction?](./generative-ai.startchatparams.systeminstruction.md) | | string \| [Part](./generative-ai.part.md) \| [Content](./generative-ai.content.md) | _(Optional)_ |
| [toolConfig?](./generative-ai.startchatparams.toolconfig.md) | | [ToolConfig](./generative-ai.toolconfig.md) | _(Optional)_ |
| [tools?](./generative-ai.startchatparams.tools.md) | | [Tool](./generative-ai.tool.md)<!-- -->\[\] | _(Optional)_ |
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
**Signature:**

```typescript
systemInstruction?: Content;
systemInstruction?: string | Part | Content;
```
50 changes: 50 additions & 0 deletions packages/main/src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,31 @@ describe("GenerativeModel", () => {
);
restore();
});
it("passes text-only systemInstruction through to generateContent", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
systemInstruction: "be friendly",
});
expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
const mockResponse = getMockResponse(
"unary-success-basic-reply-short.json",
);
const makeRequestStub = stub(request, "makeRequest").resolves(
mockResponse as Response,
);
await genModel.generateContent("hello");
expect(makeRequestStub).to.be.calledWith(
"models/my-model",
request.Task.GENERATE_CONTENT,
match.any,
false,
match((value: string) => {
return value.includes("be friendly");
}),
match.any,
);
restore();
});
it("generateContent overrides model values", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
Expand Down Expand Up @@ -226,6 +251,31 @@ describe("GenerativeModel", () => {
);
restore();
});
it("passes params through to chat.sendMessage", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
const mockResponse = getMockResponse(
"unary-success-basic-reply-short.json",
);
const makeRequestStub = stub(request, "makeRequest").resolves(
mockResponse as Response,
);
await genModel.startChat().sendMessage("hello");
expect(makeRequestStub).to.be.calledWith(
"models/my-model",
request.Task.GENERATE_CONTENT,
match.any,
false,
match((value: string) => {
return value.includes("be friendly");
}),
{},
);
restore();
});
it("startChat overrides model values", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
Expand Down
5 changes: 4 additions & 1 deletion packages/main/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import { batchEmbedContents, embedContent } from "../methods/embed-content";
import {
formatEmbedContentInput,
formatGenerateContentInput,
formatSystemInstruction,
} from "../requests/request-helpers";

/**
Expand Down Expand Up @@ -76,7 +77,9 @@ export class GenerativeModel {
this.safetySettings = modelParams.safetySettings || [];
this.tools = modelParams.tools;
this.toolConfig = modelParams.toolConfig;
this.systemInstruction = modelParams.systemInstruction;
this.systemInstruction = formatSystemInstruction(
modelParams.systemInstruction,
);
this.requestOptions = requestOptions || {};
}

Expand Down
175 changes: 175 additions & 0 deletions packages/main/src/requests/request-helpers.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { expect, use } from "chai";
import * as sinonChai from "sinon-chai";
import { Content } from "../../types";
import { formatGenerateContentInput } from "./request-helpers";

use(sinonChai);

describe("request formatting methods", () => {
describe("formatGenerateContentInput", () => {
it("formats a text string into a request", () => {
const result = formatGenerateContentInput("some text content");
expect(result).to.deep.equal({
contents: [
{
role: "user",
parts: [{ text: "some text content" }],
},
],
});
});
it("formats an array of strings into a request", () => {
const result = formatGenerateContentInput(["txt1", "txt2"]);
expect(result).to.deep.equal({
contents: [
{
role: "user",
parts: [{ text: "txt1" }, { text: "txt2" }],
},
],
});
});
it("formats an array of Parts into a request", () => {
const result = formatGenerateContentInput([
{ text: "txt1" },
{ text: "txtB" },
]);
expect(result).to.deep.equal({
contents: [
{
role: "user",
parts: [{ text: "txt1" }, { text: "txtB" }],
},
],
});
});
it("formats a mixed array into a request", () => {
const result = formatGenerateContentInput(["txtA", { text: "txtB" }]);
expect(result).to.deep.equal({
contents: [
{
role: "user",
parts: [{ text: "txtA" }, { text: "txtB" }],
},
],
});
});
it("preserves other properties of request", () => {
const result = formatGenerateContentInput({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
generationConfig: { topK: 100 },
});
expect(result).to.deep.equal({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
generationConfig: { topK: 100 },
});
});
it("formats systemInstructions if provided as text", () => {
const result = formatGenerateContentInput({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
systemInstruction: "be excited",
});
expect(result).to.deep.equal({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
});
});
it("formats systemInstructions if provided as Part", () => {
const result = formatGenerateContentInput({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
systemInstruction: { text: "be excited" },
});
expect(result).to.deep.equal({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
});
});
it("formats systemInstructions if provided as Content (no role)", () => {
const result = formatGenerateContentInput({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
systemInstruction: { parts: [{ text: "be excited" }] } as Content,
});
expect(result).to.deep.equal({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
});
});
it("passes thru systemInstructions if provided as Content", () => {
const result = formatGenerateContentInput({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
});
expect(result).to.deep.equal({
contents: [
{
role: "user",
parts: [{ text: "txtA" }],
},
],
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
});
});
});
});
31 changes: 29 additions & 2 deletions packages/main/src/requests/request-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,25 @@ import {
} from "../../types";
import { GoogleGenerativeAIError } from "../errors";

export function formatSystemInstruction(
input?: string | Part | Content,
): Content | undefined {
// null or undefined
if (input == null) {
return undefined;
} else if (typeof input === "string") {
return { role: "system", parts: [{ text: input }] } as Content;
} else if ((input as Part).text) {
return { role: "system", parts: [input as Part] };
} else if ((input as Content).parts) {
if (!(input as Content).role) {
return { role: "system", parts: (input as Content).parts };
} else {
return input as Content;
}
}
}

export function formatNewContent(
request: string | Array<string | Part>,
): Content {
Expand Down Expand Up @@ -88,12 +107,20 @@ function assignRoleToPartsAndValidateSendMessageRequest(
export function formatGenerateContentInput(
params: GenerateContentRequest | string | Array<string | Part>,
): GenerateContentRequest {
let formattedRequest: GenerateContentRequest;
if ((params as GenerateContentRequest).contents) {
return params as GenerateContentRequest;
formattedRequest = params as GenerateContentRequest;
} else {
// Array or string
const content = formatNewContent(params as string | Array<string | Part>);
return { contents: [content] };
formattedRequest = { contents: [content] };
}
if ((params as GenerateContentRequest).systemInstruction) {
formattedRequest.systemInstruction = formatSystemInstruction(
(params as GenerateContentRequest).systemInstruction,
);
}
return formattedRequest;
}

export function formatEmbedContentInput(
Expand Down
8 changes: 4 additions & 4 deletions packages/main/types/requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

import { Content } from "./content";
import { Content, Part } from "./content";
import {
FunctionCallingMode,
HarmBlockThreshold,
Expand All @@ -40,7 +40,7 @@ export interface ModelParams extends BaseParams {
model: string;
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
systemInstruction?: string | Part | Content;
}

/**
Expand All @@ -51,7 +51,7 @@ export interface GenerateContentRequest extends BaseParams {
contents: Content[];
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
systemInstruction?: string | Part | Content;
}

/**
Expand Down Expand Up @@ -84,7 +84,7 @@ export interface StartChatParams extends BaseParams {
history?: Content[];
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
systemInstruction?: string | Part | Content;
}

/**
Expand Down

0 comments on commit ca62400

Please sign in to comment.