feat(config, engine): add support for Mistral AI provider and engine (#436)
* docs(CONTRIBUTING.md): update `TODO.md` reference (#435) Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> * feat(config, engine): add support for Mistral AI provider and engine * ``` feat(package): add mistralai and zod dependencies ``` * fix: recreate package-lock.json with node20 * fix: recreate package-lock.json with node v20.18.1 based on branch dev --------- Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> Co-authored-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> Co-authored-by: pedro-valentim <>
This commit is contained in:
committed by
GitHub
parent
26ebfb416d
commit
206887612a
+48
-2
@@ -86,6 +86,48 @@ export const MODEL_LIST = {
|
||||
'llama-3.1-70b-versatile', // Llama 3.1 70B (Preview)
|
||||
'gemma-7b-it', // Gemma 7B
|
||||
'gemma2-9b-it' // Gemma 2 9B
|
||||
],
|
||||
|
||||
mistral: [
|
||||
'ministral-3b-2410',
|
||||
'ministral-3b-latest',
|
||||
'ministral-8b-2410',
|
||||
'ministral-8b-latest',
|
||||
'open-mistral-7b',
|
||||
'mistral-tiny',
|
||||
'mistral-tiny-2312',
|
||||
'open-mistral-nemo',
|
||||
'open-mistral-nemo-2407',
|
||||
'mistral-tiny-2407',
|
||||
'mistral-tiny-latest',
|
||||
'open-mixtral-8x7b',
|
||||
'mistral-small',
|
||||
'mistral-small-2312',
|
||||
'open-mixtral-8x22b',
|
||||
'open-mixtral-8x22b-2404',
|
||||
'mistral-small-2402',
|
||||
'mistral-small-2409',
|
||||
'mistral-small-latest',
|
||||
'mistral-medium-2312',
|
||||
'mistral-medium',
|
||||
'mistral-medium-latest',
|
||||
'mistral-large-2402',
|
||||
'mistral-large-2407',
|
||||
'mistral-large-2411',
|
||||
'mistral-large-latest',
|
||||
'pixtral-large-2411',
|
||||
'pixtral-large-latest',
|
||||
'codestral-2405',
|
||||
'codestral-latest',
|
||||
'codestral-mamba-2407',
|
||||
'open-codestral-mamba',
|
||||
'codestral-mamba-latest',
|
||||
'pixtral-12b-2409',
|
||||
'pixtral-12b',
|
||||
'pixtral-12b-latest',
|
||||
'mistral-embed',
|
||||
'mistral-moderation-2411',
|
||||
'mistral-moderation-latest',
|
||||
]
|
||||
};
|
||||
|
||||
@@ -101,6 +143,8 @@ const getDefaultModel = (provider: string | undefined): string => {
|
||||
return MODEL_LIST.gemini[0];
|
||||
case 'groq':
|
||||
return MODEL_LIST.groq[0];
|
||||
case 'mistral':
|
||||
return MODEL_LIST.mistral[0];
|
||||
default:
|
||||
return MODEL_LIST.openai[0];
|
||||
}
|
||||
@@ -257,14 +301,15 @@ export const configValidators = {
|
||||
CONFIG_KEYS.OCO_AI_PROVIDER,
|
||||
[
|
||||
'openai',
|
||||
'mistral',
|
||||
'anthropic',
|
||||
'gemini',
|
||||
'azure',
|
||||
'test',
|
||||
'flowise',
|
||||
'groq'
|
||||
].includes(value) || value.startsWith('ollama') || value.startsWith('mlx'),
|
||||
`${value} is not supported yet, use 'ollama', 'mlx', anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
|
||||
].includes(value) || value.startsWith('ollama'),
|
||||
`${value} is not supported yet, use 'ollama', 'mlx', 'anthropic', 'azure', 'gemini', 'flowise', 'mistral' or 'openai' (default)`
|
||||
);
|
||||
|
||||
return value;
|
||||
@@ -310,6 +355,7 @@ export enum OCO_AI_PROVIDER_ENUM {
|
||||
TEST = 'test',
|
||||
FLOWISE = 'flowise',
|
||||
GROQ = 'groq',
|
||||
MISTRAL = 'mistral',
|
||||
MLX = 'mlx'
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import { OpenAIClient as AzureOpenAIClient } from '@azure/openai';
|
||||
import { GoogleGenerativeAI as GeminiClient } from '@google/generative-ai';
|
||||
import { AxiosInstance as RawAxiosClient } from 'axios';
|
||||
import { OpenAI as OpenAIClient } from 'openai';
|
||||
import { Mistral as MistralClient } from '@mistralai/mistralai';
|
||||
|
||||
export interface AiEngineConfig {
|
||||
apiKey: string;
|
||||
@@ -17,7 +18,8 @@ type Client =
|
||||
| AzureOpenAIClient
|
||||
| AnthropicClient
|
||||
| RawAxiosClient
|
||||
| GeminiClient;
|
||||
| GeminiClient
|
||||
| MistralClient;
|
||||
|
||||
export interface AiEngine {
|
||||
config: AiEngineConfig;
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
import axios from 'axios';
|
||||
import { Mistral } from '@mistralai/mistralai';
|
||||
import { OpenAI } from 'openai';
|
||||
import { GenerateCommitMessageErrorEnum } from '../generateCommitMessageFromGitDiff';
|
||||
import { tokenCount } from '../utils/tokenCount';
|
||||
import { AiEngine, AiEngineConfig } from './Engine';
|
||||
import {
|
||||
AssistantMessage as MistralAssistantMessage,
|
||||
SystemMessage as MistralSystemMessage,
|
||||
ToolMessage as MistralToolMessage,
|
||||
UserMessage as MistralUserMessage
|
||||
} from '@mistralai/mistralai/models/components';
|
||||
|
||||
export interface MistralAiConfig extends AiEngineConfig {}
|
||||
export type MistralCompletionMessageParam = Array<
|
||||
| (MistralSystemMessage & { role: "system" })
|
||||
| (MistralUserMessage & { role: "user" })
|
||||
| (MistralAssistantMessage & { role: "assistant" })
|
||||
| (MistralToolMessage & { role: "tool" })
|
||||
>
|
||||
|
||||
export class MistralAiEngine implements AiEngine {
|
||||
config: MistralAiConfig;
|
||||
client: Mistral;
|
||||
|
||||
constructor(config: MistralAiConfig) {
|
||||
this.config = config;
|
||||
|
||||
if (!config.baseURL) {
|
||||
this.client = new Mistral({ apiKey: config.apiKey });
|
||||
} else {
|
||||
this.client = new Mistral({ apiKey: config.apiKey, serverURL: config.baseURL });
|
||||
}
|
||||
}
|
||||
|
||||
public generateCommitMessage = async (
|
||||
messages: Array<OpenAI.Chat.Completions.ChatCompletionMessageParam>
|
||||
): Promise<string | null> => {
|
||||
const params = {
|
||||
model: this.config.model,
|
||||
messages: messages as MistralCompletionMessageParam,
|
||||
topP: 0.1,
|
||||
maxTokens: this.config.maxTokensOutput
|
||||
};
|
||||
|
||||
try {
|
||||
const REQUEST_TOKENS = messages
|
||||
.map((msg) => tokenCount(msg.content as string) + 4)
|
||||
.reduce((a, b) => a + b, 0);
|
||||
|
||||
if (
|
||||
REQUEST_TOKENS >
|
||||
this.config.maxTokensInput - this.config.maxTokensOutput
|
||||
)
|
||||
throw new Error(GenerateCommitMessageErrorEnum.tooMuchTokens);
|
||||
|
||||
const completion = await this.client.chat.complete(params);
|
||||
|
||||
if (!completion.choices)
|
||||
throw Error('No completion choice available.')
|
||||
|
||||
const message = completion.choices[0].message;
|
||||
|
||||
if (!message || !message.content)
|
||||
throw Error('No completion choice available.')
|
||||
|
||||
return message.content as string;
|
||||
} catch (error) {
|
||||
const err = error as Error;
|
||||
if (
|
||||
axios.isAxiosError<{ error?: { message: string } }>(error) &&
|
||||
error.response?.status === 401
|
||||
) {
|
||||
const mistralError = error.response.data.error;
|
||||
|
||||
if (mistralError) throw new Error(mistralError.message);
|
||||
}
|
||||
|
||||
throw err;
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import { FlowiseEngine } from '../engine/flowise';
|
||||
import { GeminiEngine } from '../engine/gemini';
|
||||
import { OllamaEngine } from '../engine/ollama';
|
||||
import { OpenAiEngine } from '../engine/openAi';
|
||||
import { MistralAiEngine } from '../engine/mistral';
|
||||
import { TestAi, TestMockType } from '../engine/testAi';
|
||||
import { GroqEngine } from '../engine/groq';
|
||||
import { MLXEngine } from '../engine/mlx';
|
||||
@@ -44,6 +45,9 @@ export function getEngine(): AiEngine {
|
||||
case OCO_AI_PROVIDER_ENUM.GROQ:
|
||||
return new GroqEngine(DEFAULT_CONFIG);
|
||||
|
||||
case OCO_AI_PROVIDER_ENUM.MISTRAL:
|
||||
return new MistralAiEngine(DEFAULT_CONFIG);
|
||||
|
||||
case OCO_AI_PROVIDER_ENUM.MLX:
|
||||
return new MLXEngine(DEFAULT_CONFIG);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user