RWKV

RWKV pip 使用指南

以下内容将通过两份推理示例代码,指引你使用 RWKV pip 库。RWKV pip 库的原始代码可以在 ChatRWKV 仓库中找到。

带 CUDA Graph 的加速推理代码

以下是 RWKV 模型的两种推理模式:常规模式(Slow)带 CUDA Graph 的加速模式(Fast/CUDA Graph),后者通过消除 Python 调用开销实现极高的 Token 生成速度。

1. 环境设置与依赖导入

import os, time
import numpy as np
import torch

# 环境变量设置:必须在导入 rwkv 库之前设置
os.environ["RWKV_V7_ON"] = '1'       # 显式开启 RWKV-v7 模型支持(如果你用的是 v6 模型,请注释掉或设为 0)
os.environ['RWKV_JIT_ON'] = '1'      # 开启 JIT (Just-In-Time) 编译,加速算子加载
os.environ["RWKV_CUDA_ON"] = '1'     # 开启 CUDA 自定义算子(需系统安装了 CUDA 编译器,速度显著提升)

from rwkv.model import RWKV
from rwkv.utils import PIPELINE

注意: RWKV_V7_ON 仅适用于 RWKV-v7 架构的模型。如果加载的是 v6 或 v5 模型,请务必将其关闭或删除。

此外,RWKV_CUDA_ON='1' 需要你的环境中配置好了 nvcc 编译器(通常包含在 CUDA Toolkit 中),否则会回退到较慢的 PyTorch 原生实现或报错。

这里设置了运行 RWKV 所需的核心环境变量,并导入了必要的库。

2. 模型加载与参数配置

# 初始化模型
# strategy='cuda fp16': 使用 GPU (cuda) 进行计算,精度为 fp16
model = RWKV(model='/mnt/e/RWKV-Runner/models/rwkv7-g1a-0.1b-20250728-ctx4096', strategy='cuda fp16')

# 初始化 Pipeline
# 用于将文本转换为 token id (encode) 以及将 token id 转换回文本 (decode)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

# 生成参数设置
LENGTH_PER_TRIAL = 256  # 每次生成的 token 长度
TEMPERATURE = 1.0       # 温度:越高越随机,越低越确定
TOP_P = 0.0             # Nucleus 采样:0.0 通常意味着贪婪采样(或由库的具体实现决定,RWKV中 0 通常指argmax)

# 提示词 (Prompt)
prompt = "User: simulate SpaceX mars landing using python\n\nAssistant: <think"

这一部分加载了具体的模型权重文件,并设定了生成任务的“超参数”。

strategy='cuda fp16' 是最推荐的配置,如果你显存不足,可以尝试 cuda fp16i8(Int8 量化)。更多配置选项请参考 RWKV pip 库的推理精度与显存需求

3. 常规推理 (Slow Inference)

print('='*80 + '\nSlow inference\n' + '='*80)
print(prompt, end="")

all_tokens = []
out_last = 0

# 1. 预处理 Prompt (Prefill)
# 将提示词编码并输入模型,获取初始的 logits (out) 和状态 (state)
# state 包含了对 prompt 的所有记忆
out, state = model.forward(pipeline.encode(prompt), None)

times = []
all_times = []
t000 = time.perf_counter()

# 2. Token 生成循环
for i in range(LENGTH_PER_TRIAL):
    t00 = time.perf_counter()
    
    # 采样:根据模型输出的 logits 选择下一个 token
    token = pipeline.sample_logits(out, temperature=TEMPERATURE, top_p=TOP_P)
    all_tokens += [token]

    # 解码并打印:将生成的 token 实时转回文本显示
    tmp = pipeline.decode(all_tokens[out_last:])
    if '\ufffd' not in tmp: # 过滤掉无效的 UTF-8 字符(通常是多字节字符未接收完整时出现)
        print(tmp, end="", flush=True) 
        out_last = i+1    

    torch.cuda.synchronize() # 等待 GPU 完成上述操作,确保计时准确
    t0 = time.perf_counter()

    # 模型前向传播 (Forward)
    # 将选出的 token 和上一步的 state 输入模型,计算下一个 token 的概率
    out, state = model.forward(token, state)

    torch.cuda.synchronize()
    t1 = time.perf_counter()
    
    # 记录时间
    times.append(t1 - t0)       # 纯模型计算时间
    all_times.append(t1 - t00)  # 包含采样和 Python 逻辑的总时间

# 计算并打印速度统计
times = np.percentile(times, 50)      # 取中位数
all_times = np.percentile(all_times, 50)
print(f'\n\nToken/s = {round(1/times,2)} (forward), {round(1/all_times,2)} (full)')

这是标准的 PyTorch 逐个 Token 生成流程。

为什么叫 Slow?

在 Python 中,for 循环每一次 model.forward(token, state) 都会发起一次从 CPU 到 GPU 的内核启动调用。对于小模型或快速 GPU,Python 发起调用的时间开销(Overhead)可能比 GPU 实际计算的时间还要长。

4. 准备快速推理:CUDA Graph 录制

print('='*80 + '\nFast inference (CUDAGraph, requires rwkv pip pkg v0.8.31+)\n' + '='*80)
print(prompt, end="")

all_tokens = []
out_last = 0

# 重置状态,获取一个全零的初始状态结构
state = model.generate_zero_state()

#  关键步骤:分配静态显存 (Static Memory) 
# CUDA Graph 要求输入和输出的内存地址必须固定,不能变动。

# 1. 静态输入 Tensor (存放当前 Token 的 Embedding)
static_input = torch.empty((model.n_embd), device="cuda", dtype=torch.half)

# 2. 静态状态 Tensor (存放 RNN 隐状态)
# 需要两组:in 用于输入上一刻状态,out 用于输出当前刻状态
static_state_in = [torch.empty_like(x, device="cuda") for x in state]
static_state_out = [torch.empty_like(x, device="cuda") for x in state]

# 3. 静态输出 Tensor (存放预测下一个词的 Logits)
static_output = torch.empty((model.args.vocab_size), device="cuda", dtype=torch.half)

#  关键步骤:录制图 (Graph Capture) 
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # 在这个上下文管理器中,执行一次特殊的 forward 操作。
    # GPU 不会真的计算数据,而是"记住"了所有的计算步骤和依赖关系。
    # model.forward_one_alt 是专门为 CUDAGraph 优化的单步前向函数
    static_output, static_state_out = model.forward_one_alt(static_input, static_state_in)

核心原理: CUDA Graph 将一系列 GPU 操作打包成一个单一的“超级操作”。运行时,只需向预先分配好的 static_ 内存中填入数据,然后告诉 GPU Replay 这个图即可。这个过程绕过了 Python 对每一层算子的调度开销。

5. 执行快速推理:Graph Replay

# 1. 预处理 Prompt (Prefill)
# Prompt 处理阶段无法使用 CUDAGraph(因为长度不固定),所以使用常规 forward
out, state = model.forward(pipeline.encode(prompt), state)

# 2. 将 Prompt 计算后的状态复制到静态内存中,为 Graph 运行做准备
for i in range(len(state)):
    static_state_in[i].copy_(state[i])
    
# 将初始输出也复制进去(虽然如果是 greedy decoding 可能只用 state 就够,但保持一致性)
static_output.copy_(out)

times = []
all_times = []
t000 = time.perf_counter()

# 3. 快速生成循环
for i in range(LENGTH_PER_TRIAL):
    t00 = time.perf_counter()
    
    # 采样 (Sample) - 注意:这里直接从 static_output 读取数据
    token = pipeline.sample_logits(static_output, temperature=TEMPERATURE, top_p=TOP_P)
    all_tokens += [token]

    # 解码打印 (Decode)
    tmp = pipeline.decode(all_tokens[out_last:])
    if '\ufffd' not in tmp:
        print(tmp, end="", flush=True)
        out_last = i+1

    torch.cuda.synchronize()
    t0 = time.perf_counter()

    #  核心加速步骤 
    
    # A. 填充输入:直接从 Embedding 表中取出对应的向量,填入 static_input
    # 注意:这里我们不传 token id,而是直接传向量,减少图内部的查表开销
    static_input.copy_(model.z['emb.weight'][token])
    
    # B. 回放图:一键执行整个模型的前向计算
    g.replay()
    
    # C. 更新状态:将输出状态 (out) 复制回 输入状态 (in),形成 RNN 的循环
    # 这样下一次 replay 时,使用的就是更新后的状态了
    for n in range(len(state)):
        static_state_in[n].copy_(static_state_out[n])
    
    # --
    
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    times.append(t1 - t0)
    all_times.append(t1 - t00)

# 打印统计
times = np.percentile(times, 50)
all_times = np.percentile(all_times, 50)
print(f'\n\nToken/s = {round(1/times,2)} (forward), {round(1/all_times,2)} (full) (note: very inefficient sample_logits)')

这一段展示了如何利用录制好的图进行推理。

  • g.replay(): 替代了 model.forward()
  • 显存拷贝 ( .copy_ ): 看起来是额外操作,但在 GPU 内部(D2D copy)极快,远小于 Python 调用函数的开销。

这种方法使小参数量模型(如 0.1B, 0.4B, 1.5B)获得 2 到 5 倍推理速度提升,对于大模型的提升幅度会减小,因为大模型的主要瓶颈在于计算而非调度。

API_DEMO_CHAT.py 详解

API_DEMO_CHAT 是一个基于 RWKV pip 库的开发 Demo,用于实现基于命令行的聊天机器人

下文将以详细的注释,分段介绍这个聊天机器人 DEMO 的代码设计。

1. 环境设置与依赖导入

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

print("RWKV Chat Simple Demo")  # 打印一个简单的消息,表明这是 RWKV 聊天的简单演示。
import os, copy, types, gc, sys, re  # 导入操作系统、对象复制、类型、垃圾回收、系统、正则表达式等包
import numpy as np  # 导入 numpy 库
from prompt_toolkit import prompt  # 从 prompt_toolkit 导入 prompt,用于命令行输入
import torch  # 导入 pytorch 库

这部分代码是导入一些使用 RWKV 模型推理时需要用到的包,需要注意以下两点:

  • torch 版本最低 1.13 ,推荐 2.x+cu121
  • 需要先 pip install rwkv
# 优化 PyTorch 设置,允许使用 tf32 
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# os.environ["RWKV_V7_ON"] = '1' # 启用 RWKV-7 模型
os.environ["RWKV_JIT_ON"] = "1" # 启用 JIT 编译
os.environ["RWKV_CUDA_ON"= "0"  # 禁用原生 CUDA 算子,改成 '1' 表示启用 CUDA 算子(速度更快,但需要 c++ 编译器和 CUDA 库)

在推理 RWKV-7 模型时,请务必将 os.environ["RWKV_V7_ON"] 设置为 1

这里是一些加快推理速度的 torch 设置和操作环境的优化项。

2. 加载模型与设置参数

from rwkv.model import RWKV  # 从 RWKV 模型库中导入 RWKV 类,用于加载和操作 RWKV 模型。
from rwkv.utils import PIPELINE  # 从 RWKV 工具库中导入 PIPELINE,用于数据的编码和解码

args = types.SimpleNamespace()

args.strategy = "cuda fp16"  # 模型推理的设备和精度,使用 CUDA (GPU)并采用 FP16 精度
args.MODEL_NAME = "E://RWKV-Runner//models//rwkv-final-v6-2.1-1b6"  # 指定 RWKV 模型的路径,建议写绝对路径

这一段引入了 RWKV 工具包中的两个工具类 RWKV 和 PIPELINE ,同时指定了加载 RWKV 模型的设备精度,以及 RWKV 模型的本地文件路径。

args.strategy 会影响模型的生成效果和生成速度,cuda fp16 是最推荐的配置。

如果你显存不足,可以尝试 cuda fp16i8(Int8 量化)。更多配置选项请参考 RWKV pip 库的推理精度与显存需求

# STATE_NAME = None # 不使用 State

# 指定要加载的 State 文件路径。
STATE_NAME = "E://RWKV-Runner//models//rwkv-x060-eng_single_round_qa-1B6-20240516-ctx2048"  # 指定要加载的自定义 State 文件路径。

这一段决定是否要加载一个 State 文件,"None" 表示不加载自定义 State ,如需加载请填写 State 文件的绝对路径。

State 是 RWKV 这类 RNN 模型特有的状态。通过搭载自定义的 State 文件,可以强化 RWKV 模型在不同任务上的表现。(类似于增强插件)

RWKV State 的介绍和用法可以参照 State 文件介绍和用法文章。

# 设置模型的解码参数
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5
GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996

# 判断是否加载了一个 State 文件。如果指定了 State ,则调整生成参数,使回答的效果更好。
if STATE_NAME != None:
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3

CHUNK_LEN = 256  # 对输入进行分块处理

这里主要是设置加载或不加载 State 时, RWKV 模型分别使用哪些解码参数。

有关 RWKV 解码参数的含义和作用,请查看RWKV 解码参数文档

指定一个自定义 State 文件后,我们希望模型能更好地遵循 State 中的格式和风格,所以调低了 topp 参数和惩罚参数

CHUNK_LEN 将输入文本切分成指定大小的块。这个数值越大,模型并行处理的文本越多,但使用的显存也更多。在显存不足时建议调整到 128 或者 64。

3. 初始化和 prefill 阶段

print(f"Loading model - {args.MODEL_NAME}")# 打印模型的加载消息
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)  # 加载 RWKV 模型。
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")  # 初始化 PIPELINE ,使用 RWKV-World 词表处理输入和输出的编码/解码。

这一段开始使用前面设置的 strategy解码参数加载 RWKV 模型。

如果你希望模型加载完后也有提示,可以在这一段末尾插入:print(f"{args.MODEL_NAME} - 模型加载完毕")

model_tokens = []
model_state = None

# 如果指定了 STATE_NAME,则加载自定义 State 文件,并初始化模型 State
if STATE_NAME != None:
    args = model.args  # 获取模型参数
    state_raw = torch.load(STATE_NAME + '.pth')  # 从指定的 State 文件中加载 State 数据
    state_init = [None for i in range(args.n_layer * 3)]  # 初始化状态列表
    for i in range(args.n_layer): #开始循环,遍历每一层。
        dd = model.strategy[i]  # 获取模型每一层的加载策略
        dev = dd.device  # 获取每一层的加载设备(如 GPU)
        atype = dd.atype  # 获取每一层的数据类型(FP32/FP16 或 int8 等)
        # 初始化模型的状态
        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    model_state = copy.deepcopy(state_init)  # 复制初始化的状态

这一段代码用于加载自定义的 State 文件,将其写入模型的初始化 State 中。

通常无需修改这部分代码。

def run_rnn(ctx):
    # 定义两个全局变量,用于更新 token 和模型状态(state)
    global model_tokens, model_state
    ctx = ctx.replace("\r\n", "\n")  # 将文本中的 CRLF(Windows 系统的换行符)转换为 LF(Linux 系统的换行符)
    tokens = pipeline.encode(ctx)  # 基于 RWKV 模型的词汇表,将文本编码为 tokens
    tokens = [int(x) for x in tokens]  # 将 tokens 转换为整数(int)列表,确保类型一致性
    model_tokens += tokens  # 将 tokens 添加到全局的模型 token 列表中

    while len(tokens) > 0:  # 使用一个 while 循环执行模型前向传播,直到所有 tokens 处理完毕
        out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)  # 模型前向传播,处理大小为 CHUNK_LEN 的 token 列表,并更新模型状态
        tokens = tokens[CHUNK_LEN:]  # 移除已处理的 tokens 块,并继续处理剩余的 tokens

    return out  # 返回模型的 prefill 结果

这是控制 RWKV 模型使用 RNN 模式进行 prefill 的函数,这个函数会将 ctx(前文)切成长度为 CHUNK_LEN 的段落,一段段送入 RNN 处理,最后得到处理完前文后的 model_state 和 out 。

这个函数接收一个 ctx 参数,通常是文本(string)。然后依次对文本和文本转化的 token 进行了几项处理:

  1. 使用 replace 方法将文本的换行符统一为\n ,因为 RWKV 模型的训练数据集使用 \n 作为标准换行符格式。
  2. 使用 pipeline.encode 方法,将用户的输入文本按照 RWKV-World 词表转换成对应的 token 。
  3. 将 tokens 转换为整数(int)列表,确保类型一致性
  4. 基于当前 token 前向传播,并行处理输入文本,更新模型状态并返回 out

注意,函数返回的 out 不是具体的 token 或文本,它返回的是模型对下一个 token 的原始预测(张量)。

要将 out 转换为实际的 token 或文本,需要通过采样(例如后文中的 pipeline.sample_logits 函数)预测下一个 token ,再从 token decode 成文本

# 如果没有加载自定义 State ,则使用初始提示进行对话
if STATE_NAME == None:
    init_ctx = "User: hi" + "\n\n"
    init_ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
    run_rnn(init_ctx)  # 运行 RNN 模式对初始提示文本进行 prefill
    print(init_ctx, end="")  # 打印初始化对话文本

如果未加载任何 State 文件,则使用一段默认的对话文本进行 prefill 。

4. 推理阶段

# 从用户输入中读取消息、循环生成下一个 token 
while True:
    msg = prompt("User: ")  # 从用户输入中读取消息,存到 msg 变量
    msg = msg.strip()  # 使用 strip 方法去除消息的首尾空格
    msg = re.sub(r"\n+", "\n", msg)  # 替换多个换行符为单个换行符
    if len(msg) > 0:  # 如果处理完后,用户输入的消息非空
        occurrence = {}  # 使用 occurrence 字典这个字典用于记录每个 token 在生成上下文中出现的次数,等会用在实现重复惩罚(Penalty)
        out_tokens = []  # 使用 out_tokens 列表记录即将输出的 tokens
        out_last = 0  # 用于记录上一次生成的 token 位置

        out = run_rnn("User: " + msg + "\n\nAssistant:")  # 将用户输入拼接成 RWKV 数据集的对话格式,进行 prefill  
        print("\nAssistant:", end="")  # 打印 "Assistant:" 标签

        for i in range(99999):  
            for n in occurrence: 
                out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency  # 应用存在惩罚和频率惩罚参数
            out[0] -= 1e10  # 禁用 END_OF_TEXT 

            token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)  # 采样生成下一个 token

            out, model_state = model.forward([token], model_state)  # 模型前向传播
            model_tokens += [token] 
            out_tokens += [token]  # 将新生成的 token 添加到输出的 token 列表中

            for xxx in occurrence:
                occurrence[xxx] *= GEN_penalty_decay  # 应用衰减重复惩罚
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)  # 更新 token 的出现次数

            tmp = pipeline.decode(out_tokens[out_last:])  # 将最新生成的 token 解码成文本
            if ("\ufffd" not in tmp) and (not tmp.endswith("\n")):  # 当生成的文本是有效 UTF-8 字符串且不以换行符结尾时
                print(tmp, end="", flush=True) #实时打印解码得到的文本
                out_last = i + 1 #更新输出位置变量 out_last 

            if "\n\n" in tmp:  # 如果生成的文本包含双换行符,表示模型的响应已结束(可以将 \n\n 改成其他停止词)
                print(tmp, end="", flush=True) # 实时打印解码得到的文本
                break #结束本轮推理
    else:
        print("!!! Error: please say something !!!")  # 如果用户没有输入消息,提示“输入错误,说点啥吧!”

这一段是循环检测用户输入、并使用 RNN 模式进行推理,生成文本的功能代码。

以上代码的主要逻辑如下:

  1. 接收用户消息,规范空格空行,判断输入文本的内容长度
    • 如果规范后用户输入为空,则提示“请说点什么”
    • 如果规范后用户的输入非空,则进入步骤 2
  2. 将用户的输入拼接成聊天格式的 prompt ,然后进行 prefill ,获得 logits
  3. 预测 token ,并打印解码得到的文本字符
    • 应用存在惩罚(GEN_alpha_presence)和频率惩罚(GEN_alpha_frequency)
    • 基于 temperature 和 topp 参数对 out 进行采样,获得下一个 token
    • 使用新 token 前向传播,开启下一轮预测
    • 应用惩罚衰减参数(penalty_decay)调整 token 生成的概率
    • 把已经生成的 token 列表解码(decode)成字符文本
    • 实时输出解码得到的字符文本,判断文本里面有没有 \n\n 停止词。如果出现停止词,则退出本轮推理。

从推理过程可以看出,模型在每个时间步都更新隐藏状态(State),并利用当前的隐藏状态来生成下一个时间步的输出。这符合 RNN 的核心特性: 模型的每次输出依赖于前一步的生成结果

RWKV pip 库的推理精度与显存需求

下表中 ,fp16i8 指在 fp16 精度基础上进行 int8 量化。

量化可以减少 VRAM 需求,但在精度上略逊于 fp16。因此只要 VRAM 够用,尽量使用 fp16 层。

策略VRAM & RAM速度
cpu fp327B 模型需要 32GB 内存使用 CPU fp32 精度加载模型,适合 Intel。对 AMD 非常慢,因为 pytorch 的 cpu gemv 在 AMD 上有问题,并且只会运行在一个单核上。
cpu bf167B 模型需要 16GB 内存使用 CPU bf16 精度加载模型。在支持 bfloat16 的新 Intel CPU(如 Xeon Platinum)上速度较快。
cpu fp32i87B 模型需要 12GB 内存使用 CPU int8 量化精度加载模型。速度较慢(比 cpu fp32 更慢)。
cuda fp167B 模型需要 15GB VRAM使用 fp16 精度加载模型所有层,速度最快,但对显存(VRAM)的需求也最高。
cuda fp16i87B 模型需要 9GB VRAM使用 int8 量化模型所有层,速度较快。如果设置 os.environ["RWKV_CUDA_ON"] = '1' 来编译 CUDA 内核,可减少 1~2GB VRAM 使用。
cuda fp16i8 *20 -> cuda fp16VRAM 占用介于 fp16 和 fp16i8 之间将模型的前 20 层(*20 指层数)量化为 fp16i8,其余层使用 fp16 加载。 如果量化后还有较多 VRAM ,则酌情减少 fp16i8 层数(减少 20)。 如果 VRAM 不足则继续增加 fp16i8 量化层数
cuda fp16i8 *20+比 fp16i8 使用更少 VRAM将模型的前 20 层(*20 指层数)量化为 fp16i8 并固定在 GPU 上,其他层按需动态加载(未固定的层加载速度会慢 3 倍,但节省 VRAM)。 如果 VRAM 不足,减少固定层数(*20)。 如果 VRAM 充足,增加固定层数。
cuda fp16i8 *20 -> cpu fp32比 fp16i8 使用更少 VRAM,但消耗更多内存将模型的前 20 层(*20)量化为 fp16i8 并固定在 GPU 上,其他层使用 CPU fp32 加载。当 CPU 性能比较强时,此策略比上一个策略(只在 GPU 上固定 20 层)更快。 如果加载 20 层还有剩余 VRAM ,则继续增加 GPU 层数。 如果没有足够 VRAM,减少 GPU 层数。
cuda:0 fp16 *20 -> cuda:1 fp16使用双卡驱动模型使用 cuda:0(卡1) fp16 加载模型的前 20 层,然后使用 cuda:1(卡2) fp16 加载剩余的层(自动计算剩余层数)。 建议在最快的 GPU 上运行更多层。 如果某张卡的 VRAM 不够,可以将 fp16 换成 fp16i8 (int8 量化)。
这份文档对您有帮助吗?
联系我们© 2026 RWKV. All rights reserved.粤ICP备2024242518号-1