Merge remote-tracking branch 'origin/dev'

This commit is contained in:
di-sukharev
2023-05-26 13:07:42 +08:00
3 changed files with 165 additions and 98 deletions
+11 -1
View File
@@ -7,7 +7,9 @@ import {
OpenAIApi OpenAIApi
} from 'openai'; } from 'openai';
import { CONFIG_MODES, getConfig } from './commands/config'; import {CONFIG_MODES, DEFAULT_MODEL_TOKEN_LIMIT, getConfig} from './commands/config';
import {tokenCount} from './utils/tokenCount';
import {GenerateCommitMessageErrorEnum} from './generateCommitMessageFromGitDiff';
const config = getConfig(); const config = getConfig();
@@ -56,6 +58,14 @@ class OpenAi {
max_tokens: maxTokens || 500 max_tokens: maxTokens || 500
}; };
try { try {
const REQUEST_TOKENS = messages.map(
(msg) => tokenCount(msg.content) + 4
).reduce((a, b) => a + b, 0);
if (REQUEST_TOKENS > (DEFAULT_MODEL_TOKEN_LIMIT - maxTokens)) {
throw new Error(GenerateCommitMessageErrorEnum.tooMuchTokens);
}
const { data } = await this.openAI.createChatCompletion(params); const { data } = await this.openAI.createChatCompletion(params);
const message = data.choices[0].message; const message = data.choices[0].message;
+2
View File
@@ -22,6 +22,8 @@ export enum CONFIG_KEYS {
OCO_LANGUAGE = 'OCO_LANGUAGE' OCO_LANGUAGE = 'OCO_LANGUAGE'
} }
export const DEFAULT_MODEL_TOKEN_LIMIT = 4096;
export enum CONFIG_MODES { export enum CONFIG_MODES {
get = 'get', get = 'get',
set = 'set' set = 'set'
+64 -9
View File
@@ -3,7 +3,7 @@ import {
ChatCompletionRequestMessageRoleEnum ChatCompletionRequestMessageRoleEnum
} from 'openai'; } from 'openai';
import {api} from './api'; import {api} from './api';
import { getConfig } from './commands/config'; import {DEFAULT_MODEL_TOKEN_LIMIT, getConfig} from './commands/config';
import {mergeDiffs} from './utils/mergeDiffs'; import {mergeDiffs} from './utils/mergeDiffs';
import {i18n, I18nLocals} from './i18n'; import {i18n, I18nLocals} from './i18n';
import {tokenCount} from './utils/tokenCount'; import {tokenCount} from './utils/tokenCount';
@@ -74,24 +74,32 @@ export enum GenerateCommitMessageErrorEnum {
emptyMessage = 'EMPTY_MESSAGE' emptyMessage = 'EMPTY_MESSAGE'
} }
const INIT_MESSAGES_PROMPT_LENGTH = INIT_MESSAGES_PROMPT.map( const INIT_MESSAGES_PROMPT_LENGTH = INIT_MESSAGES_PROMPT.map(
(msg) => tokenCount(msg.content) + 4 (msg) => tokenCount(msg.content) + 4
).reduce((a, b) => a + b, 0); ).reduce((a, b) => a + b, 0);
const MAX_REQ_TOKENS = 3000 - INIT_MESSAGES_PROMPT_LENGTH; const ADJUSTMENT_FACTOR = 20;
export const generateCommitMessageByDiff = async ( export const generateCommitMessageByDiff = async (
diff: string diff: string
): Promise<string> => { ): Promise<string> => {
try { try {
if (tokenCount(diff) >= MAX_REQ_TOKENS) { const MAX_REQUEST_TOKENS = DEFAULT_MODEL_TOKEN_LIMIT
- ADJUSTMENT_FACTOR
- INIT_MESSAGES_PROMPT_LENGTH
- config?.OCO_OPENAI_MAX_TOKENS;
if (tokenCount(diff) >= MAX_REQUEST_TOKENS) {
const commitMessagePromises = getCommitMsgsPromisesFromFileDiffs( const commitMessagePromises = getCommitMsgsPromisesFromFileDiffs(
diff, diff,
MAX_REQ_TOKENS MAX_REQUEST_TOKENS
); );
const commitMessages = await Promise.all(commitMessagePromises); const commitMessages = [];
for (const promise of commitMessagePromises) {
commitMessages.push(await promise);
await delay(2000);
}
return commitMessages.join('\n\n'); return commitMessages.join('\n\n');
} else { } else {
@@ -123,9 +131,17 @@ function getMessagesPromisesByChangesInFile(
maxChangeLength maxChangeLength
); );
const lineDiffsWithHeader = mergedChanges.map( const lineDiffsWithHeader = [];
(change) => fileHeader + change for (const change of mergedChanges) {
); const totalChange = fileHeader + change;
if (tokenCount(totalChange) > maxChangeLength) {
// If the totalChange is too large, split it into smaller pieces
const splitChanges = splitDiff(totalChange, maxChangeLength);
lineDiffsWithHeader.push(...splitChanges);
} else {
lineDiffsWithHeader.push(totalChange);
}
}
const commitMsgsFromFileLineDiffs = lineDiffsWithHeader.map((lineDiff) => { const commitMsgsFromFileLineDiffs = lineDiffsWithHeader.map((lineDiff) => {
const messages = generateCommitMessageChatCompletionPrompt( const messages = generateCommitMessageChatCompletionPrompt(
@@ -138,6 +154,39 @@ function getMessagesPromisesByChangesInFile(
return commitMsgsFromFileLineDiffs; return commitMsgsFromFileLineDiffs;
} }
function splitDiff(diff: string, maxChangeLength: number) {
const lines = diff.split('\n');
const splitDiffs = [];
let currentDiff = '';
for (let line of lines) {
// If a single line exceeds maxChangeLength, split it into multiple lines
while (tokenCount(line) > maxChangeLength) {
const subLine = line.substring(0, maxChangeLength);
line = line.substring(maxChangeLength);
splitDiffs.push(subLine);
}
// Check the tokenCount of the currentDiff and the line separately
if (tokenCount(currentDiff) + tokenCount('\n' + line) > maxChangeLength) {
// If adding the next line would exceed the maxChangeLength, start a new diff
splitDiffs.push(currentDiff);
currentDiff = line;
} else {
// Otherwise, add the line to the current diff
currentDiff += '\n' + line;
}
}
// Add the last diff
if (currentDiff) {
splitDiffs.push(currentDiff);
}
return splitDiffs;
}
export function getCommitMsgsPromisesFromFileDiffs( export function getCommitMsgsPromisesFromFileDiffs(
diff: string, diff: string,
maxDiffLength: number maxDiffLength: number
@@ -169,5 +218,11 @@ export function getCommitMsgsPromisesFromFileDiffs(
commitMessagePromises.push(api.generateCommitMessage(messages)); commitMessagePromises.push(api.generateCommitMessage(messages));
} }
} }
return commitMessagePromises; return commitMessagePromises;
} }
function delay(ms: number) {
return new Promise(resolve => setTimeout(resolve, ms));
}