diff --git a/src/generateCommitMessageFromGitDiff.ts b/src/generateCommitMessageFromGitDiff.ts
index c3fb638..82b8bde 100644
--- a/src/generateCommitMessageFromGitDiff.ts
+++ b/src/generateCommitMessageFromGitDiff.ts
@@ -4,6 +4,7 @@ import {
 } from 'openai';
 import { api } from './api';
 import { getConfig } from './commands/config';
+import { mergeStrings } from './utils/mergeStrings';
 
 const config = getConfig();
 
@@ -88,37 +89,98 @@ export const generateCommitMessageWithChatCompletion = async (
 ): Promise<string | GenerateCommitMessageError> => {
   try {
     if (diff.length >= MAX_REQ_TOKENS) {
-      const separator = 'diff --git ';
-
-      const diffByFiles = diff.split(separator).slice(1);
-
-      const commitMessagePromises = diffByFiles
-        .map((fileDiff) => {
-          // TODO: split by files
-          if (fileDiff.length >= MAX_REQ_TOKENS) return null;
-
-          const messages = generateCommitMessageChatCompletionPrompt(
-            separator + fileDiff
-          );
-
-          return api.generateCommitMessage(messages);
-        })
-        .filter(Boolean);
+      const commitMessagePromises = getCommitMsgsPromisesFromFileDiffs(diff);
 
       const commitMessages = await Promise.all(commitMessagePromises);
 
-      return commitMessages.join('\n\n');
+      // Empty answers would otherwise show up as stray blank paragraphs.
+      const combinedMessage = commitMessages.filter(Boolean).join('\n\n');
+
+      if (!combinedMessage)
+        return { error: GenerateCommitMessageErrorEnum.emptyMessage };
+
+      return combinedMessage;
     }
 
     const messages = generateCommitMessageChatCompletionPrompt(diff);
 
     const commitMessage = await api.generateCommitMessage(messages);
 
     if (!commitMessage)
       return { error: GenerateCommitMessageErrorEnum.emptyMessage };
 
     return commitMessage;
   } catch (error) {
     return { error: GenerateCommitMessageErrorEnum.internalError };
   }
 };
+
+/**
+ * Starts commit-message requests for a single file diff that is too large to
+ * send in one request, by splitting it into hunks.
+ *
+ * Each request repeats the file header (the lines before the first hunk) so
+ * the model still sees which file the hunks belong to. Hunks are packed
+ * greedily; a single hunk larger than the budget is still sent on its own.
+ *
+ * @param fileDiff - one file's section of a `git diff`, header included
+ * @returns one pending request per group of hunks
+ */
+function getMessagesPromisesByLines(fileDiff: string) {
+  // Split right before lines that open a hunk ("@@ -4,6 +4,7 @@ ...");
+  // splitting on a bare '@@' would cut every hunk header in half.
+  const [fileHeader = '', ...hunks] = fileDiff.split(/^(?=@@ )/m);
+
+  // Leave room for the repeated header; `- 1` because a request only fits
+  // when it is shorter than MAX_REQ_TOKENS.
+  const mergedHunks = mergeStrings(
+    hunks,
+    MAX_REQ_TOKENS - 1 - fileHeader.length
+  );
+
+  return mergedHunks.map((hunkGroup) => {
+    const messages = generateCommitMessageChatCompletionPrompt(
+      fileHeader + hunkGroup
+    );
+
+    return api.generateCommitMessage(messages);
+  });
+}
+
+/**
+ * Splits a diff that is too large for one request into request-sized chunks
+ * and starts one commit-message request per chunk.
+ *
+ * Small file diffs are packed together; a file diff that is too large on its
+ * own is split further by hunk via getMessagesPromisesByLines.
+ *
+ * @param diff - full output of `git diff`
+ * @returns pending requests, in file order
+ */
+function getCommitMsgsPromisesFromFileDiffs(diff: string) {
+  // Split only where a *line* starts with "diff --git ": the same text can
+  // also occur inside changed lines, but there it is prefixed by '+', '-' or
+  // ' '. Each part keeps its own header line, so packed parts stay a valid
+  // diff. Anything before the first file section is dropped.
+  const diffByFiles = diff
+    .split(/^(?=diff --git )/m)
+    .filter((fileDiff) => fileDiff.startsWith('diff --git '));
+
+  // `- 1` keeps every chunk that packs several files strictly below
+  // MAX_REQ_TOKENS, so only a single oversized file reaches the hunk split.
+  const mergedDiffs = mergeStrings(diffByFiles, MAX_REQ_TOKENS - 1);
+
+  const commitMessagePromises = [];
+
+  for (const fileDiff of mergedDiffs) {
+    if (fileDiff.length >= MAX_REQ_TOKENS) {
+      commitMessagePromises.push(...getMessagesPromisesByLines(fileDiff));
+    } else {
+      const messages = generateCommitMessageChatCompletionPrompt(fileDiff);
+
+      commitMessagePromises.push(api.generateCommitMessage(messages));
+    }
+  }
+
+  return commitMessagePromises;
+}
diff --git a/src/utils/mergeStrings.ts b/src/utils/mergeStrings.ts
new file mode 100644
index 0000000..a8beb37
--- /dev/null
+++ b/src/utils/mergeStrings.ts
@@ -0,0 +1,39 @@
+/**
+ * Greedily packs consecutive strings into as few strings as possible while
+ * keeping each packed string at most `maxStringLength` characters long.
+ *
+ * Strings are concatenated as-is (no separator) and their order is kept.
+ * A string that is already longer than `maxStringLength` is returned on its
+ * own and never split.
+ *
+ * @example
+ * mergeStrings(['ab', 'cd', 'efg'], 4); // ['abcd', 'efg']
+ *
+ * @param arr - strings to pack; an empty array yields an empty array
+ * @param maxStringLength - upper bound on the length of a packed string
+ * @returns the packed strings
+ */
+export function mergeStrings(
+  arr: readonly string[],
+  maxStringLength: number
+): string[] {
+  const mergedArr: string[] = [];
+  // `undefined` until the first item is seen, so empty input is handled.
+  let currentItem: string | undefined;
+
+  for (const item of arr) {
+    if (
+      currentItem !== undefined &&
+      currentItem.length + item.length <= maxStringLength
+    ) {
+      currentItem += item;
+    } else {
+      if (currentItem !== undefined) mergedArr.push(currentItem);
+      currentItem = item;
+    }
+  }
+
+  if (currentItem !== undefined) mergedArr.push(currentItem);
+
+  return mergedArr;
+}