Skip to content

Commit

Permalink
refactor prompt config (#78)
Browse files Browse the repository at this point in the history
* Refactor prompts to be fully db-configurable.
* Migrate additional envVars to db-config
- `MAX_SOURCE_LENGTH` -> `rag.generate.maxSourceLength`
- `MAX_CONTEXT_DOC_COUNT` -> `rag.generate.maxSourceDocCount`
- `MAX_CONTEXT_LENGTH` -> `rag.generate.maxContextLength`
* [admin] Fix prompt display order
  • Loading branch information
bdb-dd authored Jun 24, 2024
1 parent b06e516 commit 167cea8
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 76 deletions.
28 changes: 13 additions & 15 deletions apps/admin/src/components/RagPromptView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,25 @@ const RagSourceView: React.FC<Params> = ({ message }) => {
},
};

const obj = message?.content?.prompts || {};

return (
<Box sx={{ flexWrap: "wrap" }}>
<React.Fragment>
<ErrorBoundary>
{" "}
<ul>
{Object.entries(message?.content?.prompts || {}).map(
([key, value], index) => (
<li key={key}>
<Box flexDirection="column">
<span>
Prompt #{index + 1}: {key}
</span>
<ReactMarkdown components={components}>
{value}
</ReactMarkdown>
</Box>
<hr />
</li>
),
)}
{["queryRelax", "generate"].map((key) => (
<li key={key}>
<Box flexDirection="column">
<span>Prompt: {key}</span>
<ReactMarkdown components={components}>
{obj[key]}
</ReactMarkdown>
</Box>
<hr />
</li>
))}
</ul>
</ErrorBoundary>
</React.Fragment>
Expand Down
87 changes: 61 additions & 26 deletions apps/slack-app/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { stripCodeBlockLang, isNullOrEmpty } from '@digdir/assistant-lib';
import { botLog, BotLogEntry, updateReactions } from './utils/bot-log';
import OpenAI from 'openai';
import { isNumber } from 'remeda';
import { RagPipelineResult } from '@digdir/assistant-lib';
import { RagPipelineParams, RagPipelineResult } from '@digdir/assistant-lib';
import { markdownToBlocks } from '@bdb-dd/mack';

const expressReceiver = new ExpressReceiver({
Expand Down Expand Up @@ -107,26 +107,55 @@ app.message(async ({ message, say }) => {
true,
);

const queryRelaxCustom = await lookupConfig(slackApp, srcEvtContext, 'prompt.rag.queryRelax', '');
const promptRagQueryRelax = await lookupConfig(
slackApp,
srcEvtContext,
'rag.queryRelax.prompt',
'',
);
if (promptRagQueryRelax == '') {
console.error('promptRagQueryRelax is empty!');
throw new Error('promptRagQueryRelax is empty!');
}

const promptRagQueryRelax =
`You have access to a search API that returns relevant documentation.
const promptRagGenerate = await lookupConfig(slackApp, srcEvtContext, 'rag.generate.prompt', '');

Your task is to generate an array of up to 7 search queries that are relevant to this question.
Use a variation of related keywords and synonyms for the queries, trying to be as general as possible.
Include as many queries as you can think of, including and excluding terms.
For example, include queries like ['keyword_1 keyword_2', 'keyword_1', 'keyword_2'].
Be creative. The more queries you include, the more likely you are to find relevant results.
` + queryRelaxCustom;
if (promptRagGenerate == '') {
console.error('promptRagGenerate is empty!');
throw new Error('promptRagGenerate is empty!');
}

const promptRagGenerateCustom = await lookupConfig(
const maxSourceLength = await lookupConfig(
slackApp,
srcEvtContext,
'prompt.rag.generate',
'',
'rag.generate.maxSourceLength',
40000,
);
const maxSourceDocCount = await lookupConfig(
slackApp,
srcEvtContext,
'rag.generate.maxSourceDocCount',
10,
);
const maxContextLength = await lookupConfig(
slackApp,
srcEvtContext,
'rag.generate.maxContextLength',
90000,
);
const maxResponseTokenCount = await lookupConfig(
slackApp,
srcEvtContext,
'rag.generate.maxResponseTokenCount',
undefined,
);

const streamCallbackFreqSec = await lookupConfig(
slackApp,
srcEvtContext,
'rag.generate.streamCallbackFreqSec',
2.0,
);
const promptRagGenerate = qaTemplate(promptRagGenerateCustom || '');

if (envVar('LOG_LEVEL') == 'debug') {
console.log(`slackApp:\n${JSON.stringify(slackApp)}`);
Expand Down Expand Up @@ -302,17 +331,23 @@ app.message(async ({ message, say }) => {
let ragResponse: RagPipelineResult | null = null;

try {
ragResponse = await ragPipeline(
stage1Result.questionTranslatedToEnglish,
userInput,
stage1Result.userInputLanguageName,
promptRagQueryRelax || '',
promptRagGenerate || '',
docsCollectionName,
phrasesCollectionName,
originalMsgCallback,
translatedMsgCallback,
);
const ragParams: RagPipelineParams = {
translated_user_query: stage1Result.questionTranslatedToEnglish,
original_user_query: userInput,
user_query_language_name: stage1Result.userInputLanguageName,
promptRagQueryRelax: promptRagQueryRelax || '',
promptRagGenerate: promptRagGenerate || '',
docsCollectionName: docsCollectionName,
phrasesCollectionName: phrasesCollectionName,
stream_callback_msg1: originalMsgCallback,
stream_callback_msg2: translatedMsgCallback,
streamCallbackFreqSec,
maxResponseTokenCount,
maxSourceDocCount,
maxSourceLength,
maxContextLength,
};
ragResponse = await ragPipeline(ragParams);

ragResponse.durations.analyze = stage1Duration;

Expand Down
83 changes: 48 additions & 35 deletions packages/assistant-lib/src/docs/rag.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,25 @@ const RagPromptSchema = z.object({
});
export type RagPrompt = z.infer<typeof RagPromptSchema>;

// Input contract for ragPipeline(). Replaces the former long positional
// parameter list with a single validated options object, so new tuning
// knobs (token limits, stream frequency) can be added without breaking callers.
const RagPipelineParamsSchema = z.object({
  translated_user_query: z.string(), // user question, translated to English for retrieval/generation
  original_user_query: z.string(), // question exactly as the user typed it
  user_query_language_name: z.string(), // e.g. "English"; non-English triggers answer translation
  promptRagQueryRelax: z.string(), // prompt for the query-relaxation step (db-configured: rag.queryRelax.prompt)
  promptRagGenerate: z.string(), // prompt for the answer-generation step (db-configured: rag.generate.prompt)
  maxSourceDocCount: z.number(), // max docs loaded into context (was env MAX_CONTEXT_DOC_COUNT)
  maxContextLength: z.number(), // total char budget for all docs in context (was env MAX_CONTEXT_LENGTH)
  maxSourceLength: z.number(), // per-document char cap before truncation (was env MAX_SOURCE_LENGTH)
  docsCollectionName: z.string(), // search collection holding the documentation
  phrasesCollectionName: z.string(), // collection used for similar-phrase lookup
  stream_callback_msg1: z.any().nullable(), // streaming callback for the English answer; only used if it is a function
  stream_callback_msg2: z.any().nullable(), // streaming callback for the translated answer
  streamCallbackFreqSec: z.number().optional(), // how often to invoke the stream callback; presumably seconds — defaults to 2.0 downstream
  maxResponseTokenCount: z.number().optional(), // optional max_tokens forwarded to the LLM call
});

export type RagPipelineParams = z.infer<typeof RagPipelineParamsSchema>;

const RagPipelineResultSchema = z.object({
original_user_query: z.string(),
english_user_query: z.string(),
Expand All @@ -56,15 +75,7 @@ const RagPipelineResultSchema = z.object({
export type RagPipelineResult = z.infer<typeof RagPipelineResultSchema>;

export async function ragPipeline(
translated_user_input: string,
original_user_input: string,
user_query_language_name: string,
promptRagQueryRelax: string,
promptRagGenerate: string,
docsCollectionName: string,
phrasesCollectionName: string,
stream_callback_msg1: any = null,
stream_callback_msg2: any = null,
params: RagPipelineParams,
): Promise<RagPipelineResult> {
const durations: any = {
total: 0,
Expand All @@ -77,15 +88,15 @@ export async function ragPipeline(
translation: 0,
};

if (envVar("MAX_CONTEXT_DOC_COUNT") == 0) {
throw new Error("MAX_CONTEXT_DOC_COUNT is set to 0");
if (params.maxSourceDocCount == 0) {
throw new Error("maxSourceDocCount is set to 0");
}
const total_start = performance.now();
var start = total_start;

const extract_search_queries = await queryRelaxation(
translated_user_input,
promptRagQueryRelax,
params.translated_user_query,
params.promptRagQueryRelax,
);
durations.generate_searches = round(lapTimer(total_start));

Expand All @@ -97,7 +108,7 @@ export async function ragPipeline(
}
start = performance.now();
const search_phrase_hits = await lookupSearchPhrasesSimilar(
phrasesCollectionName,
params.phrasesCollectionName,
extract_search_queries,
"original",
);
Expand All @@ -111,7 +122,7 @@ export async function ragPipeline(
}
start = performance.now();
const search_response = await retrieveAllByUrl(
docsCollectionName,
params.docsCollectionName,
search_phrase_hits,
);
durations["execute_searches"] = round(lapTimer(start));
Expand Down Expand Up @@ -185,9 +196,9 @@ export async function ragPipeline(
}

const rerankData = {
user_input: translated_user_input,
documents: searchHits.map((document) =>
document.content_markdown.substring(0, envVar("MAX_SOURCE_LENGTH")),
user_input: params.translated_user_query,
documents: searchHits.map((document: any) =>
document.content_markdown.substring(0, params.maxSourceLength),
),
};

Expand Down Expand Up @@ -227,11 +238,11 @@ export async function ragPipeline(

console.log("Source desc: " + sourceDesc);

let docTrimmed = docMd.substring(0, envVar("MAX_SOURCE_LENGTH"));
if (docsLength + docTrimmed.length > envVar("MAX_CONTEXT_LENGTH")) {
let docTrimmed = docMd.substring(0, params.maxSourceLength);
if (docsLength + docTrimmed.length > params.maxContextLength) {
docTrimmed = docTrimmed.substring(
0,
envVar("MAX_CONTEXT_LENGTH") - docsLength - 20,
params.maxContextLength - docsLength - 20,
);
}

Expand Down Expand Up @@ -259,8 +270,8 @@ export async function ragPipeline(
// TODO: add actual doc length, loaded doc length to dedicated lists and return

if (
docsLength >= envVar("MAX_CONTEXT_LENGTH") ||
loadedDocs.length >= envVar("MAX_CONTEXT_DOC_COUNT")
docsLength >= params.maxContextLength ||
loadedDocs.length >= params.maxSourceDocCount
) {
console.log(`Limits reached, loaded ${loadedDocs.length} docs.`);
break;
Expand All @@ -285,16 +296,16 @@ export async function ragPipeline(
let relevant_sources: string[] = [];

const contextYaml = yaml.dump(loadedDocs);
const partialPrompt = promptRagGenerate;
const partialPrompt = params.promptRagGenerate;
const fullPrompt = partialPrompt
.replace("{context}", contextYaml)
.replace("{question}", translated_user_input);
.replace("{question}", params.translated_user_query);

if (envVar("LOG_LEVEL") == "debug") {
console.log(`rag prompt:\n${partialPrompt}`);
}

if (typeof stream_callback_msg1 !== "function") {
if (typeof params.stream_callback_msg1 !== "function") {
if (envVar("USE_AZURE_OPENAI_API") === "true") {
// const chatResponse = await azureClient.chat.completions.create({
// model: envVar('AZURE_OPENAI_DEPLOYMENT'),
Expand Down Expand Up @@ -338,7 +349,9 @@ export async function ragPipeline(
content: fullPrompt,
},
],
stream_callback_msg1,
params.stream_callback_msg1,
params.streamCallbackFreqSec || 2.0,
params.maxResponseTokenCount,
);
translated_answer = english_answer;
rag_success = true;
Expand All @@ -353,22 +366,22 @@ export async function ragPipeline(
if (
translation_enabled &&
rag_success &&
user_query_language_name !== "English"
params.user_query_language_name !== "English"
) {
translated_answer = await translate(
english_answer,
user_query_language_name,
stream_callback_msg2,
params.user_query_language_name,
params.stream_callback_msg2,
);
}

durations["translation"] = round(lapTimer(start));
durations["total"] = round(lapTimer(total_start));

const response: RagPipelineResult = {
original_user_query: original_user_input,
english_user_query: translated_user_input,
user_query_language_name,
original_user_query: params.original_user_query,
english_user_query: params.translated_user_query,
user_query_language_name: params.user_query_language_name,
english_answer: english_answer || "",
translated_answer: translated_answer || "",
rag_success,
Expand All @@ -379,8 +392,8 @@ export async function ragPipeline(
not_loaded_urls: notLoadedUrls,
durations,
prompts: {
queryRelax: promptRagQueryRelax || "",
generate: promptRagGenerate || "",
queryRelax: params.promptRagQueryRelax || "",
generate: params.promptRagGenerate || "",
},
};

Expand Down
2 changes: 2 additions & 0 deletions packages/assistant-lib/src/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export async function chat_stream(
messages: Array<ChatCompletionMessageParam>,
callback: (arg0: string) => void,
callback_interval_seconds = 2.0,
max_tokens: number | null = null,
): Promise<string> {
let content_so_far = "";
let latest_chunk = "";
Expand Down Expand Up @@ -140,6 +141,7 @@ export async function chat_stream(
model: envVar("OPENAI_API_MODEL_NAME"),
temperature: 0.1,
messages: messages,
max_tokens: max_tokens,
stream: true,
});

Expand Down

0 comments on commit 167cea8

Please sign in to comment.