RWKV

RWKV pip 使用指南

以下内容将通过两份推理示例代码,指引你使用 RWKV pip 库。RWKV pip 库的原始代码可以在 ChatRWKV 仓库中找到。

带 CUDA Graph 的加速推理代码

以下是 RWKV 模型的两种推理模式:常规模式(Slow)带 CUDA Graph 的加速模式(Fast/CUDA Graph),后者通过消除 Python 调用开销实现极高的 Token 生成速度。

1. 环境设置与依赖导入

import os, time
import numpy as np
import torch

# 环境变量设置:必须在导入 rwkv 库之前设置
os.environ["RWKV_V7_ON"] = '1'       # 显式开启 RWKV-v7 模型支持(如果你用的是 v6 模型,请注释掉或设为 0)
os.environ['RWKV_JIT_ON'] = '1'      # 开启 JIT (Just-In-Time) 编译,加速算子加载
os.environ["RWKV_CUDA_ON"] = '1'     # 开启 CUDA 自定义算子(需系统安装了 CUDA 编译器,速度显著提升)

from rwkv.model import RWKV
from rwkv.utils import PIPELINE

注意: RWKV_V7_ON 仅适用于 RWKV-v7 架构的模型。如果加载的是 v6 或 v5 模型,请务必将其关闭或删除。

此外,RWKV_CUDA_ON='1' 需要你的环境中配置好了 nvcc 编译器(通常包含在 CUDA Toolkit 中),否则会回退到较慢的 PyTorch 原生实现或报错。

这里设置了运行 RWKV 所需的核心环境变量,并导入了必要的库。

2. 模型加载与参数配置

# 初始化模型
# strategy='cuda fp16': 使用 GPU (cuda) 进行计算,精度为 fp16
model = RWKV(model='/mnt/e/RWKV-Runner/models/rwkv7-g1a-0.1b-20250728-ctx4096', strategy='cuda fp16')

# 初始化 Pipeline
# 用于将文本转换为 token id (encode) 以及将 token id 转换回文本 (decode)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

# 生成参数设置
LENGTH_PER_TRIAL = 256  # 每次生成的 token 长度
TEMPERATURE = 1.0       # 温度:越高越随机,越低越确定
TOP_P = 0.0             # Nucleus 采样:0.0 通常意味着贪婪采样(或由库的具体实现决定,RWKV中 0 通常指argmax)

# 提示词 (Prompt)
prompt = "User: simulate SpaceX mars landing using python\n\nAssistant: <think"

这一部分加载了具体的模型权重文件,并设定了生成任务的“超参数”。

strategy='cuda fp16' 是最推荐的配置,如果你显存不足,可以尝试 cuda fp16i8(Int8 量化)。更多配置选项请参考 RWKV pip 库的推理精度与显存需求

3. 常规推理 (Slow Inference)

print('='*80 + '\nSlow inference\n' + '='*80)
print(prompt, end="")

all_tokens = []
out_last = 0

# 1. 预处理 Prompt (Prefill)
# 将提示词编码并输入模型,获取初始的 logits (out) 和状态 (state)
# state 包含了对 prompt 的所有记忆
out, state = model.forward(pipeline.encode(prompt), None)

times = []
all_times = []
t000 = time.perf_counter()

# 2. Token 生成循环
for i in range(LENGTH_PER_TRIAL):
    t00 = time.perf_counter()
    
    # 采样:根据模型输出的 logits 选择下一个 token
    token = pipeline.sample_logits(out, temperature=TEMPERATURE, top_p=TOP_P)
    all_tokens += [token]

    # 解码并打印:将生成的 token 实时转回文本显示
    tmp = pipeline.decode(all_tokens[out_last:])
    if '\ufffd' not in tmp: # 过滤掉无效的 UTF-8 字符(通常是多字节字符未接收完整时出现)
        print(tmp, end="", flush=True) 
        out_last = i+1    

    torch.cuda.synchronize() # 等待 GPU 完成上述操作,确保计时准确
    t0 = time.perf_counter()

    # 模型前向传播 (Forward)
    # 将选出的 token 和上一步的 state 输入模型,计算下一个 token 的概率
    out, state = model.forward(token, state)

    torch.cuda.synchronize()
    t1 = time.perf_counter()
    
    # 记录时间
    times.append(t1 - t0)       # 纯模型计算时间
    all_times.append(t1 - t00)  # 包含采样和 Python 逻辑的总时间

# 计算并打印速度统计
times = np.percentile(times, 50)      # 取中位数
all_times = np.percentile(all_times, 50)
print(f'\n\nToken/s = {round(1/times,2)} (forward), {round(1/all_times,2)} (full)')

这是标准的 PyTorch 逐个 Token 生成流程。

为什么叫 Slow?

在 Python 中,for 循环每一次 model.forward(token, state) 都会发起一次从 CPU 到 GPU 的内核启动调用。对于小模型或快速 GPU,Python 发起调用的时间开销(Overhead)可能比 GPU 实际计算的时间还要长。

4. 准备快速推理:CUDA Graph 录制

print('='*80 + '\nFast inference (CUDAGraph, requires rwkv pip pkg v0.8.31+)\n' + '='*80)
print(prompt, end="")

all_tokens = []
out_last = 0

# 重置状态,获取一个全零的初始状态结构
state = model.generate_zero_state()

#  关键步骤:分配静态显存 (Static Memory) 
# CUDA Graph 要求输入和输出的内存地址必须固定,不能变动。

# 1. 静态输入 Tensor (存放当前 Token 的 Embedding)
static_input = torch.empty((model.n_embd), device="cuda", dtype=torch.half)

# 2. 静态状态 Tensor (存放 RNN 隐状态)
# 需要两组:in 用于输入上一刻状态,out 用于输出当前刻状态
static_state_in = [torch.empty_like(x, device="cuda") for x in state]
static_state_out = [torch.empty_like(x, device="cuda") for x in state]

# 3. 静态输出 Tensor (存放预测下一个词的 Logits)
static_output = torch.empty((model.args.vocab_size), device="cuda", dtype=torch.half)

#  关键步骤:录制图 (Graph Capture) 
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # 在这个上下文管理器中,执行一次特殊的 forward 操作。
    # GPU 不会真的计算数据,而是"记住"了所有的计算步骤和依赖关系。
    # model.forward_one_alt 是专门为 CUDAGraph 优化的单步前向函数
    static_output, static_state_out = model.forward_one_alt(static_input, static_state_in)

核心原理: CUDA Graph 将一系列 GPU 操作打包成一个单一的“超级操作”。运行时,只需向预先分配好的 static_ 内存中填入数据,然后告诉 GPU Replay 这个图即可。这个过程绕过了 Python 对每一层算子的调度开销。

5. 执行快速推理:Graph Replay

# 1. 预处理 Prompt (Prefill)
# Prompt 处理阶段无法使用 CUDAGraph(因为长度不固定),所以使用常规 forward
out, state = model.forward(pipeline.encode(prompt), state)

# 2. 将 Prompt 计算后的状态复制到静态内存中,为 Graph 运行做准备
for i in range(len(state)):
    static_state_in[i].copy_(state[i])
    
# 将初始输出也复制进去(虽然如果是 greedy decoding 可能只用 state 就够,但保持一致性)
static_output.copy_(out)

times = []
all_times = []
t000 = time.perf_counter()

# 3. 快速生成循环
for i in range(LENGTH_PER_TRIAL):
    t00 = time.perf_counter()
    
    # 采样 (Sample) - 注意:这里直接从 static_output 读取数据
    token = pipeline.sample_logits(static_output, temperature=TEMPERATURE, top_p=TOP_P)
    all_tokens += [token]

    # 解码打印 (Decode)
    tmp = pipeline.decode(all_tokens[out_last:])
    if '\ufffd' not in tmp:
        print(tmp, end="", flush=True)
        out_last = i+1

    torch.cuda.synchronize()
    t0 = time.perf_counter()

    #  核心加速步骤 
    
    # A. 填充输入:直接从 Embedding 表中取出对应的向量,填入 static_input
    # 注意:这里我们不传 token id,而是直接传向量,减少图内部的查表开销
    static_input.copy_(model.z['emb.weight'][token])
    
    # B. 回放图:一键执行整个模型的前向计算
    g.replay()
    
    # C. 更新状态:将输出状态 (out) 复制回 输入状态 (in),形成 RNN 的循环
    # 这样下一次 replay 时,使用的就是更新后的状态了
    for n in range(len(state)):
        static_state_in[n].copy_(static_state_out[n])
    
    # --
    
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    times.append(t1 - t0)
    all_times.append(t1 - t00)

# 打印统计
times = np.percentile(times, 50)
all_times = np.percentile(all_times, 50)
print(f'\n\nToken/s = {round(1/times,2)} (forward), {round(1/all_times,2)} (full) (note: very inefficient sample_logits)')

这一段展示了如何利用录制好的图进行推理。

  • g.replay(): 替代了 model.forward()
  • 显存拷贝 ( .copy_ ): 看起来是额外操作,但在 GPU 内部(D2D copy)极快,远小于 Python 调用函数的开销。

这种方法使小参数量模型(如 0.1B, 0.4B, 1.5B)获得 2 到 5 倍推理速度提升,对于大模型的提升幅度会减小,因为大模型的主要瓶颈在于计算而非调度。

API_DEMO_CHAT.py 详解

API_DEMO_CHAT 是一个基于 RWKV pip 库的开发 Demo,用于实现基于命令行的聊天机器人

下文将以详细的注释,分段介绍这个聊天机器人 DEMO 的代码设计。

1. 环境设置与依赖导入

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

print("RWKV Chat Simple Demo")  # 打印一个简单的消息,表明这是 RWKV 聊天的简单演示。
import os, copy, types, gc, sys, re  # 导入操作系统、对象复制、类型、垃圾回收、系统、正则表达式等包
import numpy as np  # 导入 numpy 库
from prompt_toolkit import prompt  # 从 prompt_toolkit 导入 prompt,用于命令行输入
import torch  # 导入 pytorch 库

这部分代码是导入一些使用 RWKV 模型推理时需要用到的包,需要注意以下两点:

  • torch 版本最低 1.13 ,推荐 2.x+cu121
  • 需要先 pip install rwkv
# 优化 PyTorch 设置,允许使用 tf32 
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# os.environ["RWKV_V7_ON"] = '1' # 启用 RWKV-7 模型
os.environ["RWKV_JIT_ON"] = "1" # 启用 JIT 编译
os.environ["RWKV_CUDA_ON"= "0"  # 禁用原生 CUDA 算子,改成 '1' 表示启用 CUDA 算子(速度更快,但需要 c++ 编译器和 CUDA 库)

在推理 RWKV-7 模型时,请务必将 os.environ["RWKV_V7_ON"] 设置为 1

这里是一些加快推理速度的 torch 设置和操作环境的优化项。

2. 加载模型与设置参数

from rwkv.model import RWKV  # 从 RWKV 模型库中导入 RWKV 类,用于加载和操作 RWKV 模型。
from rwkv.utils import PIPELINE  # 从 RWKV 工具库中导入 PIPELINE,用于数据的编码和解码

args = types.SimpleNamespace()

args.strategy = "cuda fp16"  # 模型推理的设备和精度,使用 CUDA (GPU)并采用 FP16 精度
args.MODEL_NAME = "E://RWKV-Runner//models//rwkv-final-v6-2.1-1b6"  # 指定 RWKV 模型的路径,建议写绝对路径

这一段引入了 RWKV 工具包中的两个工具类 RWKV 和 PIPELINE ,同时指定了加载 RWKV 模型的设备精度,以及 RWKV 模型的本地文件路径。

args.strategy 会影响模型的生成效果和生成速度,cuda fp16 是最推荐的配置。

如果你显存不足,可以尝试 cuda fp16i8(Int8 量化)。更多配置选项请参考 RWKV pip 库的推理精度与显存需求

# STATE_NAME = None # 不使用 State

# 指定要加载的 State 文件路径。
STATE_NAME = "E://RWKV-Runner//models//rwkv-x060-eng_single_round_qa-1B6-20240516-ctx2048"  # 指定要加载的自定义 State 文件路径。

这一段决定是否要加载一个 State 文件,"None" 表示不加载自定义 State ,如需加载请填写 State 文件的绝对路径。

State 是 RWKV 这类 RNN 模型特有的状态。通过搭载自定义的 State 文件,可以强化 RWKV 模型在不同任务上的表现。(类似于增强插件)

RWKV State 的介绍和用法可以参照 State 文件介绍和用法文章。

# 设置模型的解码参数
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5
GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996

# 判断是否加载了一个 State 文件。如果指定了 State ,则调整生成参数,使回答的效果更好。
if STATE_NAME != None:
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3

CHUNK_LEN = 256  # 对输入进行分块处理

这里主要是设置加载或不加载 State 时, RWKV 模型分别使用哪些解码参数。

有关 RWKV 解码参数的含义和作用,请查看RWKV 解码参数文档

指定一个自定义 State 文件后,我们希望模型能更好地遵循 State 中的格式和风格,所以调低了 topp 参数和惩罚参数

CHUNK_LEN 将输入文本切分成指定大小的块。这个数值越大,模型并行处理的文本越多,但使用的显存也更多。在显存不足时建议调整到 128 或者 64。

3. 初始化和 prefill 阶段

print(f"Loading model - {args.MODEL_NAME}")# 打印模型的加载消息
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)  # 加载 RWKV 模型。
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")  # 初始化 PIPELINE ,使用 RWKV-World 词表处理输入和输出的编码/解码。

这一段开始使用前面设置的 strategy解码参数加载 RWKV 模型。

如果你希望模型加载完后也有提示,可以在这一段末尾插入:print(f"{args.MODEL_NAME} - 模型加载完毕")

model_tokens = []
model_state = None

# 如果指定了 STATE_NAME,则加载自定义 State 文件,并初始化模型 State
if STATE_NAME != None:
    args = model.args  # 获取模型参数
    state_raw = torch.load(STATE_NAME + '.pth')  # 从指定的 State 文件中加载 State 数据
    state_init = [None for i in range(args.n_layer * 3)]  # 初始化状态列表
    for i in range(args.n_layer): #开始循环,遍历每一层。
        dd = model.strategy[i]  # 获取模型每一层的加载策略
        dev = dd.device  # 获取每一层的加载设备(如 GPU)
        atype = dd.atype  # 获取每一层的数据类型(FP32/FP16 或 int8 等)
        # 初始化模型的状态
        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    model_state = copy.deepcopy(state_init)  # 复制初始化的状态

这一段代码用于加载自定义的 State 文件,将其写入模型的初始化 State 中。

通常无需修改这部分代码。

def run_rnn(ctx):
    # 定义两个全局变量,用于更新 token 和模型状态(state)
    global model_tokens, model_state
    ctx = ctx.replace("\r\n", "\n")  # 将文本中的 CRLF(Windows 系统的换行符)转换为 LF(Linux 系统的换行符)
    tokens = pipeline.encode(ctx)  # 基于 RWKV 模型的词汇表,将文本编码为 tokens
    tokens = [int(x) for x in tokens]  # 将 tokens 转换为整数(int)列表,确保类型一致性
    model_tokens += tokens  # 将 tokens 添加到全局的模型 token 列表中

    while len(tokens) > 0:  # 使用一个 while 循环执行模型前向传播,直到所有 tokens 处理完毕
        out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)  # 模型前向传播,处理大小为 CHUNK_LEN 的 token 列表,并更新模型状态
        tokens = tokens[CHUNK_LEN:]  # 移除已处理的 tokens 块,并继续处理剩余的 tokens

    return out  # 返回模型的 prefill 结果

这是控制 RWKV 模型使用 RNN 模式进行 prefill 的函数,这个函数会将 ctx(前文)切成长度为 CHUNK_LEN 的段落,一段段送入 RNN 处理,最后得到处理完前文后的 model_state 和 out 。

这个函数接收一个 ctx 参数,通常是文本(string)。然后依次对文本和文本转化的 token 进行了几项处理:

  1. 使用 replace 方法将文本的换行符统一为\n ,因为 RWKV 模型的训练数据集使用 \n 作为标准换行符格式。
  2. 使用 pipeline.encode 方法,将用户的输入文本按照 RWKV-World 词表转换成对应的 token 。
  3. 将 tokens 转换为整数(int)列表,确保类型一致性
  4. 基于当前 token 前向传播,并行处理输入文本,更新模型状态并返回 out

注意,函数返回的 out 不是具体的 token 或文本,它返回的是模型对下一个 token 的原始预测(张量)。

要将 out 转换为实际的 token 或文本,需要通过采样(例如后文中的 pipeline.sample_logits 函数)预测下一个 token ,再从 token decode 成文本

# 如果没有加载自定义 State ,则使用初始提示进行对话
if STATE_NAME == None:
    init_ctx = "User: hi" + "\n\n"
    init_ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
    run_rnn(init_ctx)  # 运行 RNN 模式对初始提示文本进行 prefill
    print(init_ctx, end="")  # 打印初始化对话文本

如果未加载任何 State 文件,则使用一段默认的对话文本进行 prefill 。

4. 推理阶段

# 从用户输入中读取消息、循环生成下一个 token 
while True:
    msg = prompt("User: ")  # 从用户输入中读取消息,存到 msg 变量
    msg = msg.strip()  # 使用 strip 方法去除消息的首尾空格
    msg = re.sub(r"\n+", "\n", msg)  # 替换多个换行符为单个换行符
    if len(msg) > 0:  # 如果处理完后,用户输入的消息非空
        occurrence = {}  # 使用 occurrence 字典这个字典用于记录每个 token 在生成上下文中出现的次数,等会用在实现重复惩罚(Penalty)
        out_tokens = []  # 使用 out_tokens 列表记录即将输出的 tokens
        out_last = 0  # 用于记录上一次生成的 token 位置

        out = run_rnn("User: " + msg + "\n\nAssistant:")  # 将用户输入拼接成 RWKV 数据集的对话格式,进行 prefill  
        print("\nAssistant:", end="")  # 打印 "Assistant:" 标签

        for i in range(99999):  
            for n in occurrence: 
                out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency  # 应用存在惩罚和频率惩罚参数
            out[0] -= 1e10  # 禁用 END_OF_TEXT 

            token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)  # 采样生成下一个 token

            out, model_state = model.forward([token], model_state)  # 模型前向传播
            model_tokens += [token] 
            out_tokens += [token]  # 将新生成的 token 添加到输出的 token 列表中

            for xxx in occurrence:
                occurrence[xxx] *= GEN_penalty_decay  # 应用衰减重复惩罚
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)  # 更新 token 的出现次数

            tmp = pipeline.decode(out_tokens[out_last:])  # 将最新生成的 token 解码成文本
            if ("\ufffd" not in tmp) and (not tmp.endswith("\n")):  # 当生成的文本是有效 UTF-8 字符串且不以换行符结尾时
                print(tmp, end="", flush=True) #实时打印解码得到的文本
                out_last = i + 1 #更新输出位置变量 out_last 

            if "\n\n" in tmp:  # 如果生成的文本包含双换行符,表示模型的响应已结束(可以将 \n\n 改成其他停止词)
                print(tmp, end="", flush=True) # 实时打印解码得到的文本
                break #结束本轮推理
    else:
        print("!!! Error: please say something !!!")  # 如果用户没有输入消息,提示“输入错误,说点啥吧!”

这一段是循环检测用户输入、并使用 RNN 模式进行推理,生成文本的功能代码。

以上代码的主要逻辑如下:

  1. 接收用户消息,规范空格空行,判断输入文本的内容长度
    • 如果规范后用户输入为空,则提示“请说点什么”
    • 如果规范后用户的输入非空,则进入步骤 2
  2. 将用户的输入拼接成聊天格式的 prompt ,然后进行 prefill ,获得 logits
  3. 预测 token ,并打印解码得到的文本字符
    • 应用存在惩罚(GEN_alpha_presence)和频率惩罚(GEN_alpha_frequency)
    • 基于 temperature 和 topp 参数对 out 进行采样,获得下一个 token
    • 使用新 token 前向传播,开启下一轮预测
    • 应用惩罚衰减参数(penalty_decay)调整 token 生成的概率
    • 把已经生成的 token 列表解码(decode)成字符文本
    • 实时输出解码得到的字符文本,判断文本里面有没有 \n\n 停止词。如果出现停止词,则退出本轮推理。

从推理过程可以看出,模型在每个时间步都更新隐藏状态(State),并利用当前的隐藏状态来生成下一个时间步的输出。这符合 RNN 的核心特性: 模型的每次输出依赖于前一步的生成结果

RWKV pip 库的推理精度与显存需求

下表中 ,fp16i8 指在 fp16 精度基础上进行 int8 量化。

量化可以减少 VRAM 需求,但在精度上略逊于 fp16。因此只要 VRAM 够用,尽量使用 fp16 层。

| 策略 | VRAM & RAM | 速度 | | -- | | -- | | cpu fp32 | 7B 模型需要 32GB 内存 | 使用 CPU fp32 精度加载模型,适合 Intel。对 AMD 非常慢,因为 pytorch 的 cpu gemv 在 AMD 上有问题,并且只会运行在一个单核上。 | | cpu bf16 | 7B 模型需要 16GB 内存 | 使用 CPU bf16 精度加载模型。在支持 bfloat16 的新 Intel CPU(如 Xeon Platinum)上速度较快。 | | cpu fp32i8 | 7B 模型需要 12GB 内存 | 使用 CPU int8 量化精度加载模型。速度较慢(比 cpu fp32 更慢)。 | | cuda fp16 | 7B 模型需要 15GB VRAM | 使用 fp16 精度加载模型所有层,速度最快,但对显存(VRAM)的需求也最高。 | | cuda fp16i8 | 7B 模型需要 9GB VRAM | 使用 int8 量化模型所有层,速度较快。如果设置 os.environ["RWKV_CUDA_ON"] = '1' 来编译 CUDA 内核,可减少 1~2GB VRAM 使用。 | | cuda fp16i8 *20 -> cuda fp16 | VRAM 占用介于 fp16 和 fp16i8 之间 | 将模型的前 20 层(*20 指层数)量化为 fp16i8,其余层使用 fp16 加载。 如果量化后还有较多 VRAM ,则酌情减少 fp16i8 层数(减少 20)。 如果 VRAM 不足则继续增加 fp16i8 量化层数 | | cuda fp16i8 *20+ | 比 fp16i8 使用更少 VRAM | 将模型的前 20 层(*20 指层数)量化为 fp16i8 并固定在 GPU 上,其他层按需动态加载(未固定的层加载速度会慢 3 倍,但节省 VRAM)。 如果 VRAM 不足,减少固定层数(*20)。 如果 VRAM 充足,增加固定层数。 | | cuda fp16i8 *20 -> cpu fp32 | 比 fp16i8 使用更少 VRAM,但消耗更多内存 | 将模型的前 20 层(*20)量化为 fp16i8 并固定在 GPU 上,其他层使用 CPU fp32 加载。当 CPU 性能比较强时,此策略比上一个策略(只在 GPU 上固定 20 层)更快。 如果加载 20 层还有剩余 VRAM ,则继续增加 GPU 层数。 如果没有足够 VRAM,减少 GPU 层数。 | | cuda:0 fp16 *20 -> cuda:1 fp16 | 使用双卡驱动模型 | 使用 cuda:0(卡1) fp16 加载模型的前 20 层,然后使用 cuda:1(卡2) fp16 加载剩余的层(自动计算剩余层数)。 建议在最快的 GPU 上运行更多层。 如果某张卡的 VRAM 不够,可以将 fp16 换成 fp16i8 (int8 量化)。 |

这份文档对您有帮助吗?