Merge remote-tracking branch 'origin/dev'
This commit is contained in:
+11
-1
@@ -7,7 +7,9 @@ import {
|
|||||||
OpenAIApi
|
OpenAIApi
|
||||||
} from 'openai';
|
} from 'openai';
|
||||||
|
|
||||||
import { CONFIG_MODES, getConfig } from './commands/config';
|
import {CONFIG_MODES, DEFAULT_MODEL_TOKEN_LIMIT, getConfig} from './commands/config';
|
||||||
|
import {tokenCount} from './utils/tokenCount';
|
||||||
|
import {GenerateCommitMessageErrorEnum} from './generateCommitMessageFromGitDiff';
|
||||||
|
|
||||||
const config = getConfig();
|
const config = getConfig();
|
||||||
|
|
||||||
@@ -56,6 +58,14 @@ class OpenAi {
|
|||||||
max_tokens: maxTokens || 500
|
max_tokens: maxTokens || 500
|
||||||
};
|
};
|
||||||
try {
|
try {
|
||||||
|
const REQUEST_TOKENS = messages.map(
|
||||||
|
(msg) => tokenCount(msg.content) + 4
|
||||||
|
).reduce((a, b) => a + b, 0);
|
||||||
|
|
||||||
|
if (REQUEST_TOKENS > (DEFAULT_MODEL_TOKEN_LIMIT - maxTokens)) {
|
||||||
|
throw new Error(GenerateCommitMessageErrorEnum.tooMuchTokens);
|
||||||
|
}
|
||||||
|
|
||||||
const { data } = await this.openAI.createChatCompletion(params);
|
const { data } = await this.openAI.createChatCompletion(params);
|
||||||
|
|
||||||
const message = data.choices[0].message;
|
const message = data.choices[0].message;
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ export enum CONFIG_KEYS {
|
|||||||
OCO_LANGUAGE = 'OCO_LANGUAGE'
|
OCO_LANGUAGE = 'OCO_LANGUAGE'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const DEFAULT_MODEL_TOKEN_LIMIT = 4096;
|
||||||
|
|
||||||
export enum CONFIG_MODES {
|
export enum CONFIG_MODES {
|
||||||
get = 'get',
|
get = 'get',
|
||||||
set = 'set'
|
set = 'set'
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import {
|
|||||||
ChatCompletionRequestMessageRoleEnum
|
ChatCompletionRequestMessageRoleEnum
|
||||||
} from 'openai';
|
} from 'openai';
|
||||||
import {api} from './api';
|
import {api} from './api';
|
||||||
import { getConfig } from './commands/config';
|
import {DEFAULT_MODEL_TOKEN_LIMIT, getConfig} from './commands/config';
|
||||||
import {mergeDiffs} from './utils/mergeDiffs';
|
import {mergeDiffs} from './utils/mergeDiffs';
|
||||||
import {i18n, I18nLocals} from './i18n';
|
import {i18n, I18nLocals} from './i18n';
|
||||||
import {tokenCount} from './utils/tokenCount';
|
import {tokenCount} from './utils/tokenCount';
|
||||||
@@ -74,24 +74,32 @@ export enum GenerateCommitMessageErrorEnum {
|
|||||||
emptyMessage = 'EMPTY_MESSAGE'
|
emptyMessage = 'EMPTY_MESSAGE'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
const INIT_MESSAGES_PROMPT_LENGTH = INIT_MESSAGES_PROMPT.map(
|
const INIT_MESSAGES_PROMPT_LENGTH = INIT_MESSAGES_PROMPT.map(
|
||||||
(msg) => tokenCount(msg.content) + 4
|
(msg) => tokenCount(msg.content) + 4
|
||||||
).reduce((a, b) => a + b, 0);
|
).reduce((a, b) => a + b, 0);
|
||||||
|
|
||||||
const MAX_REQ_TOKENS = 3000 - INIT_MESSAGES_PROMPT_LENGTH;
|
const ADJUSTMENT_FACTOR = 20;
|
||||||
|
|
||||||
export const generateCommitMessageByDiff = async (
|
export const generateCommitMessageByDiff = async (
|
||||||
diff: string
|
diff: string
|
||||||
): Promise<string> => {
|
): Promise<string> => {
|
||||||
try {
|
try {
|
||||||
if (tokenCount(diff) >= MAX_REQ_TOKENS) {
|
const MAX_REQUEST_TOKENS = DEFAULT_MODEL_TOKEN_LIMIT
|
||||||
|
- ADJUSTMENT_FACTOR
|
||||||
|
- INIT_MESSAGES_PROMPT_LENGTH
|
||||||
|
- config?.OCO_OPENAI_MAX_TOKENS;
|
||||||
|
|
||||||
|
if (tokenCount(diff) >= MAX_REQUEST_TOKENS) {
|
||||||
const commitMessagePromises = getCommitMsgsPromisesFromFileDiffs(
|
const commitMessagePromises = getCommitMsgsPromisesFromFileDiffs(
|
||||||
diff,
|
diff,
|
||||||
MAX_REQ_TOKENS
|
MAX_REQUEST_TOKENS
|
||||||
);
|
);
|
||||||
|
|
||||||
const commitMessages = await Promise.all(commitMessagePromises);
|
const commitMessages = [];
|
||||||
|
for (const promise of commitMessagePromises) {
|
||||||
|
commitMessages.push(await promise);
|
||||||
|
await delay(2000);
|
||||||
|
}
|
||||||
|
|
||||||
return commitMessages.join('\n\n');
|
return commitMessages.join('\n\n');
|
||||||
} else {
|
} else {
|
||||||
@@ -123,9 +131,17 @@ function getMessagesPromisesByChangesInFile(
|
|||||||
maxChangeLength
|
maxChangeLength
|
||||||
);
|
);
|
||||||
|
|
||||||
const lineDiffsWithHeader = mergedChanges.map(
|
const lineDiffsWithHeader = [];
|
||||||
(change) => fileHeader + change
|
for (const change of mergedChanges) {
|
||||||
);
|
const totalChange = fileHeader + change;
|
||||||
|
if (tokenCount(totalChange) > maxChangeLength) {
|
||||||
|
// If the totalChange is too large, split it into smaller pieces
|
||||||
|
const splitChanges = splitDiff(totalChange, maxChangeLength);
|
||||||
|
lineDiffsWithHeader.push(...splitChanges);
|
||||||
|
} else {
|
||||||
|
lineDiffsWithHeader.push(totalChange);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const commitMsgsFromFileLineDiffs = lineDiffsWithHeader.map((lineDiff) => {
|
const commitMsgsFromFileLineDiffs = lineDiffsWithHeader.map((lineDiff) => {
|
||||||
const messages = generateCommitMessageChatCompletionPrompt(
|
const messages = generateCommitMessageChatCompletionPrompt(
|
||||||
@@ -138,6 +154,39 @@ function getMessagesPromisesByChangesInFile(
|
|||||||
return commitMsgsFromFileLineDiffs;
|
return commitMsgsFromFileLineDiffs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
function splitDiff(diff: string, maxChangeLength: number) {
|
||||||
|
const lines = diff.split('\n');
|
||||||
|
const splitDiffs = [];
|
||||||
|
let currentDiff = '';
|
||||||
|
|
||||||
|
for (let line of lines) {
|
||||||
|
// If a single line exceeds maxChangeLength, split it into multiple lines
|
||||||
|
while (tokenCount(line) > maxChangeLength) {
|
||||||
|
const subLine = line.substring(0, maxChangeLength);
|
||||||
|
line = line.substring(maxChangeLength);
|
||||||
|
splitDiffs.push(subLine);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the tokenCount of the currentDiff and the line separately
|
||||||
|
if (tokenCount(currentDiff) + tokenCount('\n' + line) > maxChangeLength) {
|
||||||
|
// If adding the next line would exceed the maxChangeLength, start a new diff
|
||||||
|
splitDiffs.push(currentDiff);
|
||||||
|
currentDiff = line;
|
||||||
|
} else {
|
||||||
|
// Otherwise, add the line to the current diff
|
||||||
|
currentDiff += '\n' + line;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the last diff
|
||||||
|
if (currentDiff) {
|
||||||
|
splitDiffs.push(currentDiff);
|
||||||
|
}
|
||||||
|
|
||||||
|
return splitDiffs;
|
||||||
|
}
|
||||||
|
|
||||||
export function getCommitMsgsPromisesFromFileDiffs(
|
export function getCommitMsgsPromisesFromFileDiffs(
|
||||||
diff: string,
|
diff: string,
|
||||||
maxDiffLength: number
|
maxDiffLength: number
|
||||||
@@ -169,5 +218,11 @@ export function getCommitMsgsPromisesFromFileDiffs(
|
|||||||
commitMessagePromises.push(api.generateCommitMessage(messages));
|
commitMessagePromises.push(api.generateCommitMessage(messages));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return commitMessagePromises;
|
return commitMessagePromises;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function delay(ms: number) {
|
||||||
|
return new Promise(resolve => setTimeout(resolve, ms));
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user