feat(engine): add support for MLX AI provider (#437)

* docs(CONTRIBUTING.md): update `TODO.md` reference (#435)

Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>

* feat(engine): add support for MLX AI provider
docs/engine: update documentation to include new engine providers

* fix(mlx.ts): add repetition_penalty option to generateCommitMessage method for improved model behavior

---------

Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
Co-authored-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
This commit is contained in:
albi ⚡️
2024-12-09 11:02:38 +01:00
committed by GitHub
parent dd65b9c3e3
commit 26ebfb416d
7 changed files with 145 additions and 16 deletions
+45 -7
View File
@@ -431,8 +431,8 @@ var require_escape = __commonJS({
}
function escapeArgument(arg, doubleEscapeMetaChars) {
arg = `${arg}`;
arg = arg.replace(/(\\*)"/g, '$1$1\\"');
arg = arg.replace(/(\\*)$/, "$1$1");
arg = arg.replace(/(?=(\\+?)?)\1"/g, '$1$1\\"');
arg = arg.replace(/(?=(\\+?)?)\1$/, "$1$1");
arg = `"${arg}"`;
arg = arg.replace(metaCharsRegExp, "^$1");
if (doubleEscapeMetaChars) {
@@ -578,7 +578,7 @@ var require_enoent = __commonJS({
const originalEmit = cp.emit;
cp.emit = function(name, arg1) {
if (name === "exit") {
const err = verifyENOENT(arg1, parsed, "spawn");
const err = verifyENOENT(arg1, parsed);
if (err) {
return originalEmit.call(cp, "error", err);
}
@@ -27389,7 +27389,8 @@ var package_default = {
"test:unit:docker": "npm run test:docker-build && DOCKER_CONTENT_TRUST=0 docker run --rm oco-test npm run test:unit",
"test:e2e": "npm run test:e2e:setup && jest test/e2e",
"test:e2e:setup": "sh test/e2e/setup.sh",
"test:e2e:docker": "npm run test:docker-build && DOCKER_CONTENT_TRUST=0 docker run --rm oco-test npm run test:e2e"
"test:e2e:docker": "npm run test:docker-build && DOCKER_CONTENT_TRUST=0 docker run --rm oco-test npm run test:e2e",
"mlx:start": "OCO_AI_PROVIDER='mlx' node ./out/cli.cjs"
},
devDependencies: {
"@commitlint/types": "^17.4.4",
@@ -29933,6 +29934,8 @@ var getDefaultModel = (provider) => {
switch (provider) {
case "ollama":
return "";
case "mlx":
return "";
case "anthropic":
return MODEL_LIST.anthropic[0];
case "gemini":
@@ -29964,7 +29967,7 @@ var configValidators = {
validateConfig(
"OCO_API_KEY",
value,
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "mlx" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
);
return value;
},
@@ -30070,8 +30073,8 @@ var configValidators = {
"test",
"flowise",
"groq"
].includes(value) || value.startsWith("ollama"),
`${value} is not supported yet, use 'ollama', 'anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
].includes(value) || value.startsWith("ollama") || value.startsWith("mlx"),
`${value} is not supported yet, use 'ollama', 'mlx', anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
);
return value;
},
@@ -30111,6 +30114,7 @@ var OCO_AI_PROVIDER_ENUM = /* @__PURE__ */ ((OCO_AI_PROVIDER_ENUM2) => {
OCO_AI_PROVIDER_ENUM2["TEST"] = "test";
OCO_AI_PROVIDER_ENUM2["FLOWISE"] = "flowise";
OCO_AI_PROVIDER_ENUM2["GROQ"] = "groq";
OCO_AI_PROVIDER_ENUM2["MLX"] = "mlx";
return OCO_AI_PROVIDER_ENUM2;
})(OCO_AI_PROVIDER_ENUM || {});
var defaultConfigPath = (0, import_path.join)((0, import_os.homedir)(), ".opencommit");
@@ -44524,6 +44528,38 @@ var GroqEngine = class extends OpenAiEngine {
}
};
// src/engine/mlx.ts
var MLXEngine = class {
constructor(config7) {
this.config = config7;
this.client = axios_default.create({
url: config7.baseURL ? `${config7.baseURL}/${config7.apiKey}` : "http://localhost:8080/v1/chat/completions",
headers: { "Content-Type": "application/json" }
});
}
async generateCommitMessage(messages) {
const params = {
messages,
temperature: 0,
top_p: 0.1,
repetition_penalty: 1.5,
stream: false
};
try {
const response = await this.client.post(
this.client.getUri(this.config),
params
);
const choices = response.data.choices;
const message = choices[0].message;
return message?.content;
} catch (err) {
const message = err.response?.data?.error ?? err.message;
throw new Error(`MLX provider error: ${message}`);
}
}
};
// src/utils/engine.ts
function getEngine() {
const config7 = getConfig();
@@ -44550,6 +44586,8 @@ function getEngine() {
return new FlowiseEngine(DEFAULT_CONFIG2);
case "groq" /* GROQ */:
return new GroqEngine(DEFAULT_CONFIG2);
case "mlx" /* MLX */:
return new MLXEngine(DEFAULT_CONFIG2);
default:
return new OpenAiEngine(DEFAULT_CONFIG2);
}
+39 -3
View File
@@ -48745,6 +48745,8 @@ var getDefaultModel = (provider) => {
switch (provider) {
case "ollama":
return "";
case "mlx":
return "";
case "anthropic":
return MODEL_LIST.anthropic[0];
case "gemini":
@@ -48776,7 +48778,7 @@ var configValidators = {
validateConfig(
"OCO_API_KEY",
value,
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "mlx" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
);
return value;
},
@@ -48882,8 +48884,8 @@ var configValidators = {
"test",
"flowise",
"groq"
].includes(value) || value.startsWith("ollama"),
`${value} is not supported yet, use 'ollama', 'anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
].includes(value) || value.startsWith("ollama") || value.startsWith("mlx"),
`${value} is not supported yet, use 'ollama', 'mlx', anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
);
return value;
},
@@ -63325,6 +63327,38 @@ var GroqEngine = class extends OpenAiEngine {
}
};
// src/engine/mlx.ts
var MLXEngine = class {
constructor(config6) {
this.config = config6;
this.client = axios_default.create({
url: config6.baseURL ? `${config6.baseURL}/${config6.apiKey}` : "http://localhost:8080/v1/chat/completions",
headers: { "Content-Type": "application/json" }
});
}
async generateCommitMessage(messages) {
const params = {
messages,
temperature: 0,
top_p: 0.1,
repetition_penalty: 1.5,
stream: false
};
try {
const response = await this.client.post(
this.client.getUri(this.config),
params
);
const choices = response.data.choices;
const message = choices[0].message;
return message?.content;
} catch (err) {
const message = err.response?.data?.error ?? err.message;
throw new Error(`MLX provider error: ${message}`);
}
}
};
// src/utils/engine.ts
function getEngine() {
const config6 = getConfig();
@@ -63351,6 +63385,8 @@ function getEngine() {
return new FlowiseEngine(DEFAULT_CONFIG2);
case "groq" /* GROQ */:
return new GroqEngine(DEFAULT_CONFIG2);
case "mlx" /* MLX */:
return new MLXEngine(DEFAULT_CONFIG2);
default:
return new OpenAiEngine(DEFAULT_CONFIG2);
}