news 2026/2/23 13:41:31

在你的苹果硅 MacBook 上使用 LoRA 进行微调

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
在你的苹果硅 MacBook 上使用 LoRA 进行微调

原文:towardsdatascience.com/lora-fine-tuning-on-your-apple-silicon-macbook-432c7dab614a

随着模型变得越来越小,我们看到越来越多的消费级电脑能够本地运行 LLM(大型语言模型)。这不仅大大降低了人们训练自己模型的技术门槛,还允许尝试更多的训练技术。

一款能够很好地本地运行 LLM 的消费级电脑是苹果 Mac。苹果利用其定制的硅芯片,创建了一个数组处理库,称为 MLX。通过使用 MLX,苹果可以比许多其他消费级电脑更好地运行 LLM。

在这篇博客文章中,我将从高层次上解释 MLX 是如何工作的,然后向你展示如何使用 MLX 在本地微调你自己的 LLM。最后,我们将通过量化来加速我们的微调模型。

让我们开始吧!

MLX 背景

MLX 是什么(以及谁可以使用它?)

MLX 是苹果公司的一个开源库,它让 Mac 用户能够更高效地运行包含大量张量(tensors)的程序。自然地,当我们想要训练或微调一个模型时,这个库就派上用场了。

MLX 的工作方式是通过在中央处理单元(CPU)、图形处理单元(GPU)和内存管理单元(MMU)之间进行高效的内存传输。对于每一种系统架构,最耗时的工作是在你将内存移动到寄存器之间时。在 Nvidia GPU 上,它们通过在设备上创建大量的 SRAM 来最小化内存传输。对于苹果来说,他们设计了他们的硅芯片,使得 GPU 和 CPU 可以通过 MMU 访问相同的内存。因此,GPU 在对其数据进行操作之前不需要将其数据加载到其内存中。这种架构被称为系统级芯片(SOC),通常需要你内部构建芯片,而不是组合其他制造商预制的部件。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/0b5aa7801a06e97b4f7aa8f97f99c389.png

Image by Author – SOC 与标准内存访问模式

由于苹果现在设计了自己的硅芯片,它可以编写低级软件,从而高效地利用它。然而,这也意味着使用英特尔处理器的 Mac 用户将无法使用这个库。

安装 MLX

一旦你拥有了一台苹果硅电脑,我们就有几种方法可以安装 MLX。我将向你展示如何使用 python 虚拟环境,但请注意,你也可以通过单独的环境管理器如 conda 来安装它。

在我们的终端中,我们首先创建一个名为venv的虚拟环境,然后进入它。

python-m venv venv;source./venv/bin/activate

现在我们已经设置了环境,我们将使用 pip 来安装:

pip install mlx

运行简单的推理

在本地设置好我们的库后,让我们选择一个将要运行的模型。我喜欢使用 Phi 系列模型,因为与其他模型相比,它们相当小(3B 参数比 7B),但性能仍然相当不错。

我们可以使用相同的终端命令下载模型并进行推理:

python-m mlx_lm.generate--model microsoft/Phi-3.5-mini-instruct--prompt"Who was the first president?"--max-tokens4096

为了解释我们的命令,我们使用内置的mlx_lm函数让我们的库知道我们将使用语言模型进行推理。我们传入我们使用的模型,其名称在HuggingFace(Phi-3 在 HuggingFace 上以这种方式出现)。我们传入我们允许的最大令牌数,然后最终传入提示。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/fbddb36faf64ac51243592643ea6cc73.png

作者提供的图片 - "谁是第一位总统"的基础模型

微调

生成微调数据集

为了使我们的示例简单但有用,我们将微调模型,使其始终以以下模式以 JSON 格式响应:

{"context":"...","question":"...","answer":"..."}

要使用 MLX 进行微调,我们需要我们的数据集以 MLX 理解的模式。有 4 种格式:chattoolscompletionstext。我们将专注于completions,这样当我们对模型进行提示时,它将以 JSON 格式返回答案。Completions 要求我们的训练数据使用以下模式:

{"prompt":"...","completion":"...",}

现在我们已经了解了如何将数据传递给 MLX,我们需要找到一个好的微调数据集。我创建以下 Python 脚本来处理squad_v2数据集,使其符合 MLX 所需的模式。

fromdatasetsimportload_datasetimportjsonimportrandomprint("Loading dataset and tokenizer...")qa_dataset=load_dataset("squad_v2")defcreate_completion(context,question,answer):iflen(answer["text"])<1:answer_text="I Don't Know"else:answer_text=answer["text"][0]completion_template={"context":context,"question":question,"answer":answer_text}returnjson.dumps(completion_template)defprocess_dataset(dataset):processed_data=[]forsampleindataset:completion=create_completion(sample['context'],sample['question'],sample['answers'])prompt=sample['question']processed_data.append({"prompt":prompt,"completion":completion})returnprocessed_dataprint("Processing training data...")train_data=process_dataset(qa_dataset['train'])print("Processing validation data...")valid_data=process_dataset(qa_dataset['validation'])# SQuAD v2 uses 'validation' as test set# Combine all data for redistributionall_data=train_data+valid_data random.shuffle(all_data)# Calculate new split sizestotal_size=len(all_data)train_size=int(0.8*total_size)test_size=int(0.1*total_size)valid_size=total_size-train_size-test_size# Split the datanew_train_data=all_data[:train_size]new_test_data=all_data[train_size:train_size+test_size]new_valid_data=all_data[train_size+test_size:]# Write to JSONL filesdefwrite_jsonl(data,filename):withopen(filename,'w')asf:foritemindata:f.write(json.dumps(item)+'n')print("Writing train.jsonl...")folder_prefix="./data/"write_jsonl(new_train_data,folder_prefix+'train.jsonl')print("Writing test.jsonl...")write_jsonl(new_test_data,folder_prefix+'test.jsonl')print("Writing valid.jsonl...")write_jsonl(new_valid_data,folder_prefix+'valid.jsonl')print(f"Dataset split and saved: train ({len(new_train_data)}), test ({len(new_test_data)}), valid ({len(new_valid_data)})")# Verify file contentsdefcount_lines(filename):withopen(folder_prefix+filename,'r')asf:returnsum(1for_inf)print("nVerifying file contents:")print(f"train.jsonl:{count_lines('train.jsonl')}lines")print(f"test.jsonl:{count_lines('test.jsonl')}lines")print(f"valid.jsonl:{count_lines('valid.jsonl')}lines")

重要的是,在squad_v2数据集中,我们有示例,其中答案未知,我们明确告诉它写“我不知道”。这有助于通过向模型展示在给定上下文不知道答案时应该做什么来减少幻觉。

在这一步骤结束时,我们现在有一个如下所示的数据集,分为训练、测试和验证文件:

{"prompt":"...","completion":"{"context": "...","question":"...","answer":"..."}"}

LoRA 微调

为了进行微调,我们将使用 MLX 内置的 LoRA 函数。要了解更多关于 LoRA 背后的数学和理论,请查看我的博客文章。

python-m mlx_lm.lora--model microsoft/Phi-3.5-mini-instruct--train--data./data--iters100

运行此脚本,我们发现可以实现最终验证损失为 1.530,考虑到我们只更新了模型 0.082%的权重,这并不坏。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/8367a25fe9defd92f51b00029bc98546.png

作者提供的图片 - 微调输出

你会注意到最后我们已将新的 LoRA 权重保存为适配器。适配器保存了我们在微调期间学到的应该对权重进行的更新。我们拥有单独的适配器文件,而不是立即更新模型,因为我们可能有一个糟糕的训练运行或想要为不同的任务保留多个微调。为了给我们提供更多选择,我们通常将基本权重与更新分开存储,直到我们想要通过融合将权重永久化。

微调模型推理

现在我们已经生成了适配器,让我们看看如何在推理期间使用它们以获得更好的输出。我们希望测试输出是否如预期那样。在我们的案例中,我们期望给定一个提示,模型会以我们之前做的 JSON 模式给出我们的答案。

我们再次使用mlx_lm.generate命令,但这次我们传递了额外的参数adapter-path。这告诉 MLX 在哪里找到额外的权重,并确保在推理时使用它们。

python-m mlx_lm.generate--model microsoft/Phi-3.5-mini-instruct--adapter-path./adapters--prompt"Who was the first president?"--max-tokens4096

当我们运行上述命令时,我们看到我们得到了一个 JSON 格式的响应,其中包含了我们微调时指定的键。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/3c29204bbf8bf4dcc9a6aba84c767850.png

作者提供的图片 – “谁是第一位总统”的微调模型输出

向 LoRA 传递更多参数

我们很幸运,我们的第一次运行使模型很好地遵循了我们的格式。如果我们遇到了更多问题,我们就会想要为 LoRA 指定更多的参数来考虑。为此,你创建一个lora_config.yaml文件,并将其传递给 LoRA 命令,如下所示。在此处查看示例 yaml 配置文件。

python-m mlx_lm.lora--config<path_to_file>

量化

什么是量化?

从上面的运行中,我们可以看到模型使用了大量的资源。生成每个标记需要大约 17 秒,并在峰值时使用了大约 7GB 的内存。虽然在某些情况下推理一个大模型是有意义的,但对我们来说,我们希望以最少的成本运行本地 LLM。因此,我们希望模型使用更少的内存并运行得更快。在不改变模型架构的情况下,我们可以通过量化来优化这里。

要理解量化,让我先解释一下我们如何存储模型的参数。每个参数都是一个数字,通常在科学计算中,我们使用浮点表示来确保我们的计算尽可能准确(要了解更多关于这里的确切布局,请查看我的博客)。然而,正如你所看到的,这需要大量的位来表示每个数字。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/39df7a9cb33b2e313d64d1e9c559c435.png

作者提供的图片 – 浮点 32 位(FP32)的 IEEE 表示

由于我们倾向于使用数十亿个参数,每个参数的大小对模型的总体内存占用有显著影响。此外,浮点运算通常比整数运算需要更多的计算资源。正是这两方面的压力促使人们尝试使用新的数据类型来存储参数。当我们对模型进行量化时,我们可以从使用浮点数转换为使用整数。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/d481c40b658ab3eb7184e8e31dea38ae.png

作者提供的图片 – 8 位整数表示

这里的权衡是,我们能够更快地进行计算并使用更少的内存,但我们的性能往往会随着参数值的不精确而下降。这里的艺术在于,在加快速度和使模型占用更少空间的同时,尽可能保持基模型的性能。

量化我们的模型

要量化我们的模型,我们运行以下命令:

python-m mlx_lm.convert--hf-path microsoft/Phi-3-mini-4k-instruct-q--q-bits4

我们通过传递-q标志来告诉模型进行量化,然后使用--q-bits标志指定每个权重的位数。

完成后,它将在本地创建一个名为mlx_model的文件夹,用于存储我们新的量化模型。它将把存储在 HuggingFace 中的所有权重转换为用 4 位表示的整数(这是最大的减少之一)。

QLoRA

现在我们有了量化模型,我们可以使用与运行 LoRA 相同的训练数据和命令在它上面运行 QLoRA。MLX 足够智能,能够看到如果权重被量化,它应该切换到使用 QLoRA。

我们终端命令看起来和之前几乎一样,但这次我们告诉它使用本地已有的量化模型作为源,而不是 hugging face 上的模型。

python-m mlx_lm.lora--model./mlx_model--train--data./data--iters100

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/887c72c5176df90c29ce3b86a0ab98fc.png

作者提供的图片 – QLoRA 运行过程中的微调输出

现在我们可以推理我们的 QLoRA 微调模型并进行比较:

python-m mlx_lm.generate--model./mlx_model--adapter-path./adapters--prompt"Who was the first president?"--max-tokens4096

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b9c7b2dbd63012ad07745c820d91b5b0.png

作者提供的图片 – “谁是第一位总统”的微调模型输出

与原始微调相比,我们可以看到内存使用量显著降低,每秒生成的 token 数量也显著提高。当我们将其发送给用户时,他们肯定会注意到速度更快。为了确定质量,我们不得不比较函数之间的损失。

对于 LoRA 模型,我们最后的验证损失是 1.530,而 QLoRA 模型的损失是 1.544。虽然预期 LoRA 模型的损失会更小,但 QLoRA 模型并没有相差太远,这意味着我们做得相当不错!

结束语

最后,这篇博客向您展示了如何使用 Mac 和 MLX 在本地微调您自己的 LLM。随着越来越多的计算能力被引入消费硬件,我们可以期待更多的训练技术成为可能。这可以为 ML 打开更多的用例,并帮助我们解决更多的问题。

要查看此博客使用的完整代码,请查看下面的 GitHub 仓库:

GitHub – matthewjgunton/mlx_json_lora

现在是构建模型的好时机!


[1] Hannun, A.,等人,“mlx” (2024),Github

[2] Lo, K.,等人,“Phi-3CookBook” (2024),Github

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/2/22 21:38:09

资源嗅探工具实操指南:3大核心场景+7个实战技巧

资源嗅探工具实操指南&#xff1a;3大核心场景7个实战技巧 【免费下载链接】cat-catch 猫抓 chrome资源嗅探扩展 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch 你是否曾遇到想保存网页视频却找不到下载按钮的尴尬&#xff1f;是否在社交媒体看到精彩内容…

作者头像 李华
网站建设 2026/2/23 9:19:39

如何实现应用级位置隔离?安卓虚拟定位工具深度技术指南

如何实现应用级位置隔离&#xff1f;安卓虚拟定位工具深度技术指南 【免费下载链接】FakeLocation Xposed module to mock locations per app. 项目地址: https://gitcode.com/gh_mirrors/fak/FakeLocation 在移动互联网时代&#xff0c;应用级定位控制已成为保护隐私与…

作者头像 李华
网站建设 2026/2/22 4:09:09

解锁QQ音乐加密文件:跨设备自由播放的终极解决方案

解锁QQ音乐加密文件&#xff1a;跨设备自由播放的终极解决方案 【免费下载链接】qmcdump 一个简单的QQ音乐解码&#xff08;qmcflac/qmc0/qmc3 转 flac/mp3&#xff09;&#xff0c;仅为个人学习参考用。 项目地址: https://gitcode.com/gh_mirrors/qm/qmcdump 还在为QQ…

作者头像 李华
网站建设 2026/2/22 12:25:26

SiameseUIE中文-base效果实测:中文OCR后文本的噪声鲁棒性抽取能力

SiameseUIE中文-base效果实测&#xff1a;中文OCR后文本的噪声鲁棒性抽取能力 1. 为什么OCR后的文本特别考验信息抽取模型&#xff1f; 你有没有遇到过这样的情况&#xff1a;扫描合同、截图发票、翻拍古籍&#xff0c;再用OCR工具转成文字&#xff0c;结果满屏都是错别字、漏…

作者头像 李华
网站建设 2026/2/20 16:41:10

3分钟搞定网页资源下载:告别99%的媒体保存难题

3分钟搞定网页资源下载&#xff1a;告别99%的媒体保存难题 【免费下载链接】cat-catch 猫抓 chrome资源嗅探扩展 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch 你是否曾遇到这样的困境&#xff1a;精心挑选的在线课程视频无法保存、设计师需要的高清素材…

作者头像 李华