Feat/add gemini (#349)
This commit is contained in:
@@ -110,7 +110,7 @@ OCO_TOKENS_MAX_OUTPUT=<max response tokens (default: 500)>
|
||||
OCO_OPENAI_BASE_PATH=<may be used to set proxy path to OpenAI api>
|
||||
OCO_DESCRIPTION=<postface a message with ~3 sentences description of the changes>
|
||||
OCO_EMOJI=<boolean, add GitMoji>
|
||||
OCO_MODEL=<either 'gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo' (default), 'gpt-3.5-turbo-0125', 'gpt-4-1106-preview', 'gpt-4-turbo-preview' or 'gpt-4-0125-preview'>
|
||||
OCO_MODEL=<either 'gpt-4o', 'gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo' (default), 'gpt-3.5-turbo-0125', 'gpt-4-1106-preview', 'gpt-4-turbo-preview' or 'gpt-4-0125-preview'>
|
||||
OCO_LANGUAGE=<locale, scroll to the bottom to see options>
|
||||
OCO_MESSAGE_TEMPLATE_PLACEHOLDER=<message template placeholder, default: '$msg'>
|
||||
OCO_PROMPT_MODULE=<either conventional-commit or @commitlint, default: conventional-commit>
|
||||
|
||||
+4422
-3522
File diff suppressed because it is too large
Load Diff
+4487
-3590
File diff suppressed because it is too large
Load Diff
Generated
+9
@@ -16,6 +16,7 @@
|
||||
"@azure/openai": "^1.0.0-beta.12",
|
||||
"@clack/prompts": "^0.6.1",
|
||||
"@dqbd/tiktoken": "^1.0.2",
|
||||
"@google/generative-ai": "^0.11.4",
|
||||
"@octokit/webhooks-schemas": "^6.11.0",
|
||||
"@octokit/webhooks-types": "^6.11.0",
|
||||
"ai": "^2.2.14",
|
||||
@@ -1051,6 +1052,14 @@
|
||||
"node": ">=14"
|
||||
}
|
||||
},
|
||||
"node_modules/@google/generative-ai": {
|
||||
"version": "0.11.4",
|
||||
"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.11.4.tgz",
|
||||
"integrity": "sha512-hlw+E9Prv9aUIQISRnLSXi4rukFqKe5WhxPvzBccTvIvXjw2BHMFOJWSC/Gq7WE0W+L/qRHGmYxopmx9qjrB9w==",
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@humanwhocodes/config-array": {
|
||||
"version": "0.11.14",
|
||||
"resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.14.tgz",
|
||||
|
||||
@@ -43,11 +43,13 @@
|
||||
"start": "node ./out/cli.cjs",
|
||||
"ollama:start": "OCO_AI_PROVIDER='ollama' node ./out/cli.cjs",
|
||||
"dev": "ts-node ./src/cli.ts",
|
||||
"dev:gemini": "OCO_AI_PROVIDER='gemini' ts-node ./src/cli.ts",
|
||||
"build": "rimraf out && node esbuild.config.js",
|
||||
"build:push": "npm run build && git add . && git commit -m 'build' && git push",
|
||||
"deploy": "npm version patch && npm run build:push && git push --tags && npm publish --tag latest",
|
||||
"lint": "eslint src --ext ts && tsc --noEmit",
|
||||
"format": "prettier --write src",
|
||||
"test": "node --no-warnings --experimental-vm-modules $( [ -f ./node_modules/.bin/jest ] && echo ./node_modules/.bin/jest || which jest ) test/unit",
|
||||
"test:all": "npm run test:unit:docker && npm run test:e2e:docker",
|
||||
"test:docker-build": "docker build -t oco-test -f test/Dockerfile .",
|
||||
"test:unit": "NODE_OPTIONS=--experimental-vm-modules jest test/unit",
|
||||
@@ -81,6 +83,7 @@
|
||||
"@anthropic-ai/sdk": "^0.19.2",
|
||||
"@clack/prompts": "^0.6.1",
|
||||
"@dqbd/tiktoken": "^1.0.2",
|
||||
"@google/generative-ai": "^0.11.4",
|
||||
"@octokit/webhooks-schemas": "^6.11.0",
|
||||
"@octokit/webhooks-types": "^6.11.0",
|
||||
"ai": "^2.2.14",
|
||||
|
||||
+64
-23
@@ -15,6 +15,8 @@ export enum CONFIG_KEYS {
|
||||
OCO_OPENAI_API_KEY = 'OCO_OPENAI_API_KEY',
|
||||
OCO_ANTHROPIC_API_KEY = 'OCO_ANTHROPIC_API_KEY',
|
||||
OCO_AZURE_API_KEY = 'OCO_AZURE_API_KEY',
|
||||
OCO_GEMINI_API_KEY = 'OCO_GEMINI_API_KEY',
|
||||
OCO_GEMINI_BASE_PATH = 'OCO_GEMINI_BASE_PATH',
|
||||
OCO_TOKENS_MAX_INPUT = 'OCO_TOKENS_MAX_INPUT',
|
||||
OCO_TOKENS_MAX_OUTPUT = 'OCO_TOKENS_MAX_OUTPUT',
|
||||
OCO_OPENAI_BASE_PATH = 'OCO_OPENAI_BASE_PATH',
|
||||
@@ -36,18 +38,32 @@ export enum CONFIG_MODES {
|
||||
}
|
||||
|
||||
export const MODEL_LIST = {
|
||||
openai: ['gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-0125',
|
||||
'gpt-4',
|
||||
'gpt-4-turbo',
|
||||
'gpt-4-1106-preview',
|
||||
'gpt-4-turbo-preview',
|
||||
'gpt-4-0125-preview',
|
||||
'gpt-4o'],
|
||||
|
||||
openai: [
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-0125',
|
||||
'gpt-4',
|
||||
'gpt-4-turbo',
|
||||
'gpt-4-1106-preview',
|
||||
'gpt-4-turbo-preview',
|
||||
'gpt-4-0125-preview',
|
||||
'gpt-4o',
|
||||
],
|
||||
|
||||
anthropic: ['claude-3-haiku-20240307',
|
||||
'claude-3-sonnet-20240229',
|
||||
'claude-3-opus-20240229']
|
||||
anthropic: [
|
||||
'claude-3-haiku-20240307',
|
||||
'claude-3-sonnet-20240229',
|
||||
'claude-3-opus-20240229',
|
||||
],
|
||||
|
||||
gemini: [
|
||||
'gemini-1.5-flash',
|
||||
'gemini-1.5-pro',
|
||||
'gemini-1.0-pro',
|
||||
'gemini-pro-vision',
|
||||
'text-embedding-004',
|
||||
],
|
||||
|
||||
}
|
||||
|
||||
const getDefaultModel = (provider: string | undefined): string => {
|
||||
@@ -56,6 +72,8 @@ const getDefaultModel = (provider: string | undefined): string => {
|
||||
return '';
|
||||
case 'anthropic':
|
||||
return MODEL_LIST.anthropic[0];
|
||||
case 'gemini':
|
||||
return MODEL_LIST.gemini[0];
|
||||
default:
|
||||
return MODEL_LIST.openai[0];
|
||||
}
|
||||
@@ -82,6 +100,8 @@ const validateConfig = (
|
||||
|
||||
export const configValidators = {
|
||||
[CONFIG_KEYS.OCO_OPENAI_API_KEY](value: any, config: any = {}) {
|
||||
if (config.OCO_AI_PROVIDER == 'gemini') return value;
|
||||
|
||||
//need api key unless running locally with ollama
|
||||
validateConfig(
|
||||
'OpenAI API_KEY',
|
||||
@@ -116,6 +136,29 @@ export const configValidators = {
|
||||
|
||||
return value;
|
||||
},
|
||||
|
||||
[CONFIG_KEYS.OCO_GEMINI_API_KEY](value: any, config: any = {}) {
|
||||
// only need to check for gemini api key if using gemini
|
||||
if (config.OCO_AI_PROVIDER != 'gemini') return value;
|
||||
|
||||
validateConfig(
|
||||
'Gemini API Key',
|
||||
value || config.OCO_GEMINI_API_KEY || config.OCO_AI_PROVIDER == 'test',
|
||||
'You need to provide an Gemini API key'
|
||||
);
|
||||
|
||||
return value;
|
||||
},
|
||||
|
||||
[CONFIG_KEYS.OCO_ANTHROPIC_API_KEY](value: any, config: any = {}) {
|
||||
validateConfig(
|
||||
'ANTHROPIC_API_KEY',
|
||||
value || config.OCO_OPENAI_API_KEY || config.OCO_AI_PROVIDER == 'ollama' || config.OCO_AI_PROVIDER == 'test',
|
||||
'You need to provide an OpenAI/Anthropic API key'
|
||||
);
|
||||
|
||||
return value;
|
||||
},
|
||||
|
||||
[CONFIG_KEYS.OCO_DESCRIPTION](value: any) {
|
||||
validateConfig(
|
||||
@@ -196,15 +239,11 @@ export const configValidators = {
|
||||
[CONFIG_KEYS.OCO_MODEL](value: any, config: any = {}) {
|
||||
validateConfig(
|
||||
CONFIG_KEYS.OCO_MODEL,
|
||||
[...MODEL_LIST.openai, ...MODEL_LIST.anthropic].includes(value) || config.OCO_AI_PROVIDER == 'ollama' || config.OCO_AI_PROVIDER == 'test'|| config.OCO_AI_PROVIDER == 'azure',
|
||||
`${value} is not supported yet, use 'gpt-4o', 'gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo' (default), 'gpt-3.5-turbo-0125', 'gpt-4-1106-preview', 'gpt-4-turbo-preview', 'gpt-4-0125-preview', 'claude-3-opus-20240229', 'claude-3-sonnet-20240229' or 'claude-3-haiku-20240307'`
|
||||
);
|
||||
validateConfig(
|
||||
CONFIG_KEYS.OCO_MODEL,
|
||||
typeof value === 'string' &&
|
||||
value.match(/^[a-zA-Z0-9~\-]{1,63}[a-zA-Z0-9]$/) ||
|
||||
config.OCO_AI_PROVIDER != 'azure',
|
||||
`${value} is not model deployed name.`
|
||||
[...MODEL_LIST.openai, ...MODEL_LIST.anthropic, ...MODEL_LIST.gemini].includes(value) ||
|
||||
config.OCO_AI_PROVIDER == 'ollama' ||
|
||||
config.OCO_AI_PROVIDER == 'azure' ||
|
||||
config.OCO_AI_PROVIDER == 'test',
|
||||
`${value} is not supported yet, use:\n\n ${[...MODEL_LIST.openai, ...MODEL_LIST.anthropic, ...MODEL_LIST.gemini].join('\n')}`
|
||||
);
|
||||
return value;
|
||||
},
|
||||
@@ -243,11 +282,11 @@ export const configValidators = {
|
||||
'',
|
||||
'openai',
|
||||
'anthropic',
|
||||
'azure',
|
||||
'ollama',
|
||||
'gemini',
|
||||
'azure',
|
||||
'test'
|
||||
].includes(value) || value.startsWith('ollama'),
|
||||
`${value} is not supported yet, use 'ollama/{model}', 'azure', 'anthropic' or 'openai' (default)`
|
||||
`${value} is not supported yet, use 'ollama', 'anthropic', 'azure', 'gemini' or 'openai' (default)`
|
||||
);
|
||||
return value;
|
||||
},
|
||||
@@ -291,6 +330,7 @@ export const getConfig = ({
|
||||
OCO_OPENAI_API_KEY: process.env.OCO_OPENAI_API_KEY,
|
||||
OCO_ANTHROPIC_API_KEY: process.env.OCO_ANTHROPIC_API_KEY,
|
||||
OCO_AZURE_API_KEY: process.env.OCO_AZURE_API_KEY,
|
||||
OCO_GEMINI_API_KEY: process.env.OCO_GEMINI_API_KEY,
|
||||
OCO_TOKENS_MAX_INPUT: process.env.OCO_TOKENS_MAX_INPUT
|
||||
? Number(process.env.OCO_TOKENS_MAX_INPUT)
|
||||
: undefined,
|
||||
@@ -298,6 +338,7 @@ export const getConfig = ({
|
||||
? Number(process.env.OCO_TOKENS_MAX_OUTPUT)
|
||||
: undefined,
|
||||
OCO_OPENAI_BASE_PATH: process.env.OCO_OPENAI_BASE_PATH,
|
||||
OCO_GEMINI_BASE_PATH: process.env.OCO_GEMINI_BASE_PATH,
|
||||
OCO_DESCRIPTION: process.env.OCO_DESCRIPTION === 'true' ? true : false,
|
||||
OCO_EMOJI: process.env.OCO_EMOJI === 'true' ? true : false,
|
||||
OCO_MODEL: process.env.OCO_MODEL || getDefaultModel(process.env.OCO_AI_PROVIDER),
|
||||
|
||||
@@ -59,7 +59,7 @@ if (provider === 'anthropic' &&
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
class AnthropicAi implements AiEngine {
|
||||
export class AnthropicAi implements AiEngine {
|
||||
private anthropicAiApiConfiguration = {
|
||||
apiKey: apiKey
|
||||
};
|
||||
@@ -120,5 +120,3 @@ class AnthropicAi implements AiEngine {
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
export const anthropicAi = new AnthropicAi();
|
||||
+1
-1
@@ -54,7 +54,7 @@ if (
|
||||
|
||||
const MODEL = config?.OCO_MODEL || 'gpt-3.5-turbo';
|
||||
|
||||
class Azure implements AiEngine {
|
||||
export class Azure implements AiEngine {
|
||||
private openAI!: OpenAIClient;
|
||||
|
||||
constructor() {
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import { ChatCompletionRequestMessage } from 'openai';
|
||||
import { AiEngine } from './Engine';
|
||||
import { Content, GenerativeModel, GoogleGenerativeAI, HarmBlockThreshold, HarmCategory, Part } from '@google/generative-ai';
|
||||
import { CONFIG_MODES, ConfigType, DEFAULT_TOKEN_LIMITS, getConfig, MODEL_LIST } from '../commands/config';
|
||||
import { intro, outro } from '@clack/prompts';
|
||||
import chalk from 'chalk';
|
||||
import axios from 'axios';
|
||||
|
||||
|
||||
export class Gemini implements AiEngine {
|
||||
|
||||
private readonly config: ConfigType;
|
||||
private readonly googleGenerativeAi: GoogleGenerativeAI;
|
||||
private ai: GenerativeModel;
|
||||
|
||||
// vars
|
||||
private maxTokens = {
|
||||
input: DEFAULT_TOKEN_LIMITS.DEFAULT_MAX_TOKENS_INPUT,
|
||||
output: DEFAULT_TOKEN_LIMITS.DEFAULT_MAX_TOKENS_OUTPUT
|
||||
};
|
||||
private basePath: string;
|
||||
private apiKey: string;
|
||||
private model: string;
|
||||
|
||||
constructor() {
|
||||
this.config = getConfig() as ConfigType;
|
||||
this.googleGenerativeAi = new GoogleGenerativeAI(this.config.OCO_GEMINI_API_KEY);
|
||||
|
||||
this.warmup();
|
||||
}
|
||||
|
||||
async generateCommitMessage(messages: ChatCompletionRequestMessage[]): Promise<string | undefined> {
|
||||
const systemInstruction = messages.filter(m => m.role === 'system')
|
||||
.map(m => m.content)
|
||||
.join('\n');
|
||||
|
||||
this.ai = this.googleGenerativeAi.getGenerativeModel({
|
||||
model: this.model,
|
||||
systemInstruction,
|
||||
});
|
||||
|
||||
const contents = messages.filter(m => m.role !== 'system')
|
||||
.map(m => ({ parts: [{ text: m.content } as Part], role: m.role == 'user' ? m.role : 'model', } as Content));
|
||||
|
||||
try {
|
||||
const result = await this.ai.generateContent({
|
||||
contents,
|
||||
safetySettings: [
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
|
||||
},
|
||||
],
|
||||
generationConfig: {
|
||||
maxOutputTokens: this.maxTokens.output,
|
||||
temperature: 0,
|
||||
topP: 0.1,
|
||||
},
|
||||
});
|
||||
|
||||
return result.response.text();
|
||||
} catch (error) {
|
||||
const err = error as Error;
|
||||
outro(`${chalk.red('✖')} ${err?.message || err}`);
|
||||
|
||||
if (
|
||||
axios.isAxiosError<{ error?: { message: string } }>(error) &&
|
||||
error.response?.status === 401
|
||||
) {
|
||||
const geminiError = error.response.data.error;
|
||||
|
||||
if (geminiError?.message) outro(geminiError.message);
|
||||
outro(
|
||||
'For help look into README https://github.com/di-sukharev/opencommit#setup'
|
||||
);
|
||||
}
|
||||
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
private warmup(): void {
|
||||
if (this.config.OCO_TOKENS_MAX_INPUT !== undefined) this.maxTokens.input = this.config.OCO_TOKENS_MAX_INPUT;
|
||||
if (this.config.OCO_TOKENS_MAX_OUTPUT !== undefined) this.maxTokens.output = this.config.OCO_TOKENS_MAX_OUTPUT;
|
||||
this.basePath = this.config.OCO_GEMINI_BASE_PATH;
|
||||
this.apiKey = this.config.OCO_GEMINI_API_KEY;
|
||||
|
||||
const [command, mode] = process.argv.slice(2);
|
||||
|
||||
const provider = this.config.OCO_AI_PROVIDER;
|
||||
|
||||
if (provider === 'gemini' && !this.apiKey &&
|
||||
command !== 'config' && mode !== 'set') {
|
||||
intro('opencommit');
|
||||
|
||||
outro('OCO_GEMINI_API_KEY is not set, please run `oco config set OCO_GEMINI_API_KEY=<your token> . If you are using GPT, make sure you add payment details, so API works.');
|
||||
|
||||
outro(
|
||||
'For help look into README https://github.com/di-sukharev/opencommit#setup'
|
||||
);
|
||||
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
this.model = this.config.OCO_MODEL || MODEL_LIST.gemini[0];
|
||||
|
||||
if (provider === 'gemini' &&
|
||||
!MODEL_LIST.gemini.includes(this.model) &&
|
||||
command !== 'config' &&
|
||||
mode !== CONFIG_MODES.set) {
|
||||
outro(
|
||||
`${chalk.red('✖')} Unsupported model ${this.model} for Gemini. Supported models are: ${MODEL_LIST.gemini.join(
|
||||
', '
|
||||
)}`
|
||||
);
|
||||
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -45,5 +45,3 @@ export class OllamaAi implements AiEngine {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const ollamaAi = new OllamaAi();
|
||||
|
||||
@@ -66,7 +66,8 @@ if (provider === 'openai' &&
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
class OpenAi implements AiEngine {
|
||||
export class OpenAi implements AiEngine {
|
||||
|
||||
private openAiApiConfiguration = new OpenAiApiConfiguration({
|
||||
apiKey: apiKey
|
||||
});
|
||||
@@ -91,7 +92,7 @@ class OpenAi implements AiEngine {
|
||||
};
|
||||
try {
|
||||
const REQUEST_TOKENS = messages
|
||||
.map((msg) => tokenCount(msg.content) + 4)
|
||||
.map((msg) => tokenCount(msg.content as string) + 4)
|
||||
.reduce((a, b) => a + b, 0);
|
||||
|
||||
if (REQUEST_TOKENS > MAX_TOKENS_INPUT - MAX_TOKENS_OUTPUT) {
|
||||
@@ -124,6 +125,6 @@ class OpenAi implements AiEngine {
|
||||
throw err;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
export const api = new OpenAi();
|
||||
|
||||
@@ -9,4 +9,3 @@ export class TestAi implements AiEngine {
|
||||
}
|
||||
}
|
||||
|
||||
export const testAi = new TestAi();
|
||||
|
||||
@@ -49,7 +49,7 @@ export const generateCommitMessageByDiff = async (
|
||||
const INIT_MESSAGES_PROMPT = await getMainCommitPrompt(fullGitMojiSpec);
|
||||
|
||||
const INIT_MESSAGES_PROMPT_LENGTH = INIT_MESSAGES_PROMPT.map(
|
||||
(msg) => tokenCount(msg.content) + 4
|
||||
(msg) => tokenCount(msg.content as string) + 4
|
||||
).reduce((a, b) => a + b, 0);
|
||||
|
||||
const MAX_REQUEST_TOKENS =
|
||||
@@ -65,9 +65,9 @@ export const generateCommitMessageByDiff = async (
|
||||
fullGitMojiSpec
|
||||
);
|
||||
|
||||
const commitMessages = [];
|
||||
const commitMessages = [] as string[];
|
||||
for (const promise of commitMessagePromises) {
|
||||
commitMessages.push(await promise);
|
||||
commitMessages.push((await promise) as string);
|
||||
await delay(2000);
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ function getMessagesPromisesByChangesInFile(
|
||||
maxChangeLength
|
||||
);
|
||||
|
||||
const lineDiffsWithHeader = [];
|
||||
const lineDiffsWithHeader = [] as string[];
|
||||
for (const change of mergedChanges) {
|
||||
const totalChange = fileHeader + change;
|
||||
if (tokenCount(totalChange) > maxChangeLength) {
|
||||
@@ -135,7 +135,7 @@ function getMessagesPromisesByChangesInFile(
|
||||
|
||||
function splitDiff(diff: string, maxChangeLength: number) {
|
||||
const lines = diff.split('\n');
|
||||
const splitDiffs = [];
|
||||
const splitDiffs = [] as string[];
|
||||
let currentDiff = '';
|
||||
|
||||
if (maxChangeLength <= 0) {
|
||||
@@ -181,7 +181,7 @@ export const getCommitMsgsPromisesFromFileDiffs = async (
|
||||
// merge multiple files-diffs into 1 prompt to save tokens
|
||||
const mergedFilesDiffs = mergeDiffs(diffByFiles, maxDiffLength);
|
||||
|
||||
const commitMessagePromises = [];
|
||||
const commitMessagePromises = [] as Promise<string | undefined>[];
|
||||
|
||||
for (const fileDiff of mergedFilesDiffs) {
|
||||
if (tokenCount(fileDiff) >= maxDiffLength) {
|
||||
|
||||
+19
-13
@@ -1,26 +1,32 @@
|
||||
import { AiEngine } from '../engine/Engine';
|
||||
import { api } from '../engine/openAi';
|
||||
import { OpenAi } from '../engine/openAi';
|
||||
import { Gemini } from '../engine/gemini';
|
||||
import { getConfig } from '../commands/config';
|
||||
import { ollamaAi } from '../engine/ollama';
|
||||
import { azure } from '../engine/azure';
|
||||
import { anthropicAi } from '../engine/anthropic'
|
||||
import { testAi } from '../engine/testAi';
|
||||
import { OllamaAi } from '../engine/ollama';
|
||||
import { AnthropicAi } from '../engine/anthropic'
|
||||
import { TestAi } from '../engine/testAi';
|
||||
import { Azure } from '../engine/azure';
|
||||
|
||||
export function getEngine(): AiEngine {
|
||||
const config = getConfig();
|
||||
const provider = config?.OCO_AI_PROVIDER;
|
||||
|
||||
if (provider?.startsWith('ollama')) {
|
||||
const ollamaAi = new OllamaAi();
|
||||
const model = provider.split('/')[1];
|
||||
if (model) ollamaAi.setModel(model);
|
||||
|
||||
return ollamaAi;
|
||||
} else if (config?.OCO_AI_PROVIDER == 'anthropic') {
|
||||
return anthropicAi;
|
||||
} else if (config?.OCO_AI_PROVIDER == 'test') {
|
||||
return testAi;
|
||||
} else if (config?.OCO_AI_PROVIDER == 'azure') {
|
||||
return azure;
|
||||
} else if (provider == 'anthropic') {
|
||||
return new AnthropicAi();
|
||||
} else if (provider == 'test') {
|
||||
return new TestAi();
|
||||
} else if (provider == 'gemini') {
|
||||
return new Gemini();
|
||||
} else if (provider == 'azure') {
|
||||
return new Azure();
|
||||
}
|
||||
// open ai gpt by default
|
||||
return api;
|
||||
|
||||
//open ai gpt by default
|
||||
return new OpenAi();
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import { prepareEnvironment } from './utils';
|
||||
|
||||
it('cli flow when there are no changes', async () => {
|
||||
const { gitDir, cleanup } = await prepareEnvironment();
|
||||
|
||||
const { findByText } = await render(`OCO_AI_PROVIDER='test' node`, [resolve('./out/cli.cjs')], { cwd: gitDir });
|
||||
expect(await findByText('No changes detected')).toBeInTheConsole();
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ it('cli flow to generate commit message for 1 new file (staged)', async () => {
|
||||
await render('git' ,['add index.ts'], { cwd: gitDir });
|
||||
|
||||
const { queryByText, findByText, userEvent } = await render(`OCO_AI_PROVIDER='test' node`, [resolve('./out/cli.cjs')], { cwd: gitDir });
|
||||
|
||||
expect(await queryByText('No files are staged')).not.toBeInTheConsole();
|
||||
expect(await queryByText('Do you want to stage all files and generate commit message?')).not.toBeInTheConsole();
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import 'cli-testing-library/extend-expect'
|
||||
import { configure } from 'cli-testing-library'
|
||||
import { jest } from '@jest/globals';
|
||||
|
||||
global.jest = jest;
|
||||
|
||||
/**
|
||||
* Adjusted the wait time for waitFor/findByText to 2000ms, because the default 1000ms makes the test results flaky
|
||||
|
||||
@@ -55,7 +55,7 @@ OCO_ONE_LINE_COMMIT="true"
|
||||
expect(config!['OCO_LANGUAGE']).toEqual('de');
|
||||
expect(config!['OCO_MESSAGE_TEMPLATE_PLACEHOLDER']).toEqual('$m');
|
||||
expect(config!['OCO_PROMPT_MODULE']).toEqual('@commitlint');
|
||||
expect(config!['OCO_AI_PROVIDER']).toEqual('ollama');
|
||||
expect(() => ['ollama', 'gemini'].includes(config!['OCO_AI_PROVIDER'])).toBeTruthy();
|
||||
expect(config!['OCO_GITPUSH']).toEqual(false);
|
||||
expect(config!['OCO_ONE_LINE_COMMIT']).toEqual(true);
|
||||
|
||||
@@ -96,7 +96,7 @@ OCO_ONE_LINE_COMMIT="true"
|
||||
expect(config!['OCO_LANGUAGE']).toEqual('de');
|
||||
expect(config!['OCO_MESSAGE_TEMPLATE_PLACEHOLDER']).toEqual('$m');
|
||||
expect(config!['OCO_PROMPT_MODULE']).toEqual('@commitlint');
|
||||
expect(config!['OCO_AI_PROVIDER']).toEqual('ollama');
|
||||
expect(() => ['ollama', 'gemini'].includes(config!['OCO_AI_PROVIDER'])).toBeTruthy();
|
||||
expect(config!['OCO_GITPUSH']).toEqual(false);
|
||||
expect(config!['OCO_ONE_LINE_COMMIT']).toEqual(true);
|
||||
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
import { Gemini } from '../../src/engine/gemini';
|
||||
import { ChatCompletionRequestMessage } from 'openai';
|
||||
import { GenerativeModel, GoogleGenerativeAI } from '@google/generative-ai';
|
||||
import { ConfigType, getConfig } from '../../src/commands/config';
|
||||
|
||||
describe('Gemini', () => {
|
||||
let gemini: Gemini;
|
||||
let mockConfig: ConfigType;
|
||||
let mockGoogleGenerativeAi: GoogleGenerativeAI;
|
||||
let mockGenerativeModel: GenerativeModel;
|
||||
let mockExit: jest.SpyInstance<never, [code?: number | undefined], any>;
|
||||
let mockWarmup: jest.SpyInstance<any, unknown[], any>;
|
||||
|
||||
const noop: (code?: number | undefined) => never = (code?: number | undefined) => {};
|
||||
|
||||
const mockGemini = () => {
|
||||
gemini = new Gemini();
|
||||
}
|
||||
|
||||
const oldEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetModules();
|
||||
process.env = { ...oldEnv };
|
||||
|
||||
jest.mock('@google/generative-ai');
|
||||
jest.mock('../src/commands/config');
|
||||
|
||||
jest.mock('@clack/prompts', () => ({
|
||||
intro: jest.fn(),
|
||||
outro: jest.fn(),
|
||||
}));
|
||||
|
||||
if (mockWarmup) mockWarmup.mockRestore();
|
||||
|
||||
mockExit = jest.spyOn(process, 'exit').mockImplementation();
|
||||
mockConfig = getConfig() as ConfigType;
|
||||
|
||||
mockConfig.OCO_AI_PROVIDER = 'gemini';
|
||||
mockConfig.OCO_GEMINI_API_KEY = 'mock-api-key';
|
||||
mockConfig.OCO_MODEL = 'gemini-1.5-flash';
|
||||
|
||||
mockGoogleGenerativeAi = new GoogleGenerativeAI(mockConfig.OCO_GEMINI_API_KEY);
|
||||
mockGenerativeModel = mockGoogleGenerativeAi.getGenerativeModel({ model: mockConfig.OCO_MODEL, });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
gemini = undefined as any;
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
mockExit.mockRestore();
|
||||
process.env = oldEnv;
|
||||
});
|
||||
|
||||
it('should initialize with correct config', () => {
|
||||
mockGemini();
|
||||
// gemini = new Gemini();
|
||||
expect(gemini).toBeDefined();
|
||||
});
|
||||
|
||||
it('should warmup correctly', () => {
|
||||
mockWarmup = jest.spyOn(Gemini.prototype as any, 'warmup').mockImplementation(noop);
|
||||
mockGemini();
|
||||
expect(gemini).toBeDefined();
|
||||
});
|
||||
|
||||
it('should exit process if OCO_GEMINI_API_KEY is not set and command is not config', () => {
|
||||
process.env.OCO_GEMINI_API_KEY = undefined;
|
||||
process.env.OCO_AI_PROVIDER = 'gemini';
|
||||
|
||||
mockGemini();
|
||||
|
||||
expect(mockExit).toHaveBeenCalledWith(1);
|
||||
});
|
||||
|
||||
it('should exit process if model is not supported and command is not config', () => {
|
||||
process.env.OCO_GEMINI_API_KEY = undefined;
|
||||
process.env.OCO_AI_PROVIDER = 'gemini';
|
||||
|
||||
mockGemini();
|
||||
|
||||
expect(mockExit).toHaveBeenCalledWith(1);
|
||||
});
|
||||
|
||||
it('should generate commit message', async () => {
|
||||
const mockGenerateContent = jest.fn().mockResolvedValue({ response: { text: () => 'generated content' } });
|
||||
mockGenerativeModel.generateContent = mockGenerateContent;
|
||||
|
||||
mockWarmup = jest.spyOn(Gemini.prototype as any, 'warmup').mockImplementation(noop);
|
||||
mockGemini();
|
||||
|
||||
const messages: ChatCompletionRequestMessage[] = [
|
||||
{ role: 'system', content: 'system message' },
|
||||
{ role: 'assistant', content: 'assistant message' },
|
||||
];
|
||||
|
||||
jest.spyOn(gemini, 'generateCommitMessage').mockImplementation(async () => 'generated content');
|
||||
const result = await gemini.generateCommitMessage(messages);
|
||||
|
||||
expect(result).toEqual('generated content');
|
||||
expect(mockWarmup).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
});
|
||||
+5
-5
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ESNext",
|
||||
"lib": ["ES5", "ES6"],
|
||||
"target": "ES2020",
|
||||
"lib": ["ES6", "ES2020"],
|
||||
|
||||
"module": "ESNext",
|
||||
// "rootDir": "./src",
|
||||
"module": "CommonJS",
|
||||
|
||||
"resolveJsonModule": true,
|
||||
"moduleResolution": "node",
|
||||
"moduleResolution": "Node",
|
||||
|
||||
"allowJs": true,
|
||||
|
||||
|
||||
Reference in New Issue
Block a user