feat(engine): add support for MLX AI provider (#437)

* docs(CONTRIBUTING.md): update `TODO.md` reference (#435)

Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>

* feat(engine): add support for MLX AI provider
docs/engine: update documentation to include new engine providers

* fix(mlx.ts): add repetition_penalty option to generateCommitMessage method for improved model behavior

---------

Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
Co-authored-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
This commit is contained in:
albi ⚡️
2024-12-09 11:02:38 +01:00
committed by GitHub
parent dd65b9c3e3
commit 26ebfb416d
7 changed files with 145 additions and 16 deletions
+39 -3
View File
@@ -48745,6 +48745,8 @@ var getDefaultModel = (provider) => {
switch (provider) {
case "ollama":
return "";
case "mlx":
return "";
case "anthropic":
return MODEL_LIST.anthropic[0];
case "gemini":
@@ -48776,7 +48778,7 @@ var configValidators = {
validateConfig(
"OCO_API_KEY",
value,
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "mlx" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
);
return value;
},
@@ -48882,8 +48884,8 @@ var configValidators = {
"test",
"flowise",
"groq"
].includes(value) || value.startsWith("ollama"),
`${value} is not supported yet, use 'ollama', 'anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
].includes(value) || value.startsWith("ollama") || value.startsWith("mlx"),
`${value} is not supported yet, use 'ollama', 'mlx', anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
);
return value;
},
@@ -63325,6 +63327,38 @@ var GroqEngine = class extends OpenAiEngine {
}
};
// src/engine/mlx.ts
var MLXEngine = class {
constructor(config6) {
this.config = config6;
this.client = axios_default.create({
url: config6.baseURL ? `${config6.baseURL}/${config6.apiKey}` : "http://localhost:8080/v1/chat/completions",
headers: { "Content-Type": "application/json" }
});
}
async generateCommitMessage(messages) {
const params = {
messages,
temperature: 0,
top_p: 0.1,
repetition_penalty: 1.5,
stream: false
};
try {
const response = await this.client.post(
this.client.getUri(this.config),
params
);
const choices = response.data.choices;
const message = choices[0].message;
return message?.content;
} catch (err) {
const message = err.response?.data?.error ?? err.message;
throw new Error(`MLX provider error: ${message}`);
}
}
};
// src/utils/engine.ts
function getEngine() {
const config6 = getConfig();
@@ -63351,6 +63385,8 @@ function getEngine() {
return new FlowiseEngine(DEFAULT_CONFIG2);
case "groq" /* GROQ */:
return new GroqEngine(DEFAULT_CONFIG2);
case "mlx" /* MLX */:
return new MLXEngine(DEFAULT_CONFIG2);
default:
return new OpenAiEngine(DEFAULT_CONFIG2);
}