数据准备
prompt 文件格式、单条 prompt schema、图像/条件输入,以及 prompt 如何展开为 rollout group。
UniRL 训练是 prompt-first:数据文件提供 prompt,rollout engine 生成 media,reward 组件再打分。 本页说明可接受的文件格式与单条 prompt 的 schema。数据集挂载位置见 数据与模型。
文件格式
用 DATA_PATH(环境变量,recipe 会插值它)或覆盖 recipe 的 data_source data_path(Hydra override)指定 prompt 文件。读取器
(unirl/data/datasets.py)接受三种扩展名,其它会报错:
| 扩展名 | 解析 |
|---|---|
.txt | 每个非空行一个 prompt,每行变成 {"prompt": <line>}。 |
.jsonl | 每个非空行一个 JSON 对象。 |
.json | 字符串或对象的 list;或带 prompts list、caption、或可配置 prompt key(默认 prompt)的 dict。 |
最小 JSON:
[
{"prompt": "A watercolor landscape with snowy mountains at sunrise."},
{"prompt": "A cinematic portrait of a robot reading under warm light."}
]纯 .txt 文件(每行一个 prompt,如已提交的 datasets/pickscore/train.txt)也适用于文生视频 recipe:
A drone shot flying over a misty pine forest at dawn.
Time-lapse of clouds rolling over a desert canyon.单条 Prompt Schema
每个对象会被规范化为一条 prompt example:
| 字段 | 必填 | 说明 |
|---|---|---|
prompt(或 caption) | 是 | 非空文本。 |
prompt_id | 否 | 省略时自动生成为 {文件名}:{索引}。 |
metadata | 否 | 自由 dict;省略时其余顶层键自动归入 metadata。 |
media / media_refs | 否 | media 引用列表,每项为 {modality, role, uri}。 |
若省略 metadata,其余顶层键(除 prompt、caption、media、media_refs、metadata、prompt_id
以外的)会并入它;若显式传入 metadata dict,则按原样使用,额外字段请放进它里面。遗留的预计算 embedding 字段
(如 prompt_embed_path、prompt_embeds)会被硬报错——embedding 在 runtime 计算。
数据文件里没有 negative_prompt,也没有逐行 seed。guidance scale、seed、分辨率来自 cfg.sampling,
不来自 manifest 行。
图像条件 / Edit / I2V 输入
图生视频、编辑或其它带条件的 recipe,通过 media_refs 附加条件图像:
{
"prompt": "Animate this scene with gentle falling snow.",
"media_refs": [
{"modality": "image", "role": "condition", "uri": "frames/scene_01.png"}
]
}- 相对 URI 相对数据文件所在目录解析。
- 绝对路径与
http://、https://、s3://、gs://URI 原样透传。 - 目前 driver 每条 prompt 只加载一个
(modality="image", role="condition")引用;其它 modality/role 组合会抛NotImplementedError。 - 数据契约里没有 video URI role:文生视频用
.txtprompt,图生视频用图像 condition 引用。
Prompt 如何变成 Rollout Group
两个旋钮控制 batch 形状,作用在不同阶段:
prompts_per_rollout是每个 rollout 采样的不同 prompt 数(data loader 的 batch size)。prompt 不会被预先复制。sampling.samples_per_prompt随后在 rollout pipeline 里把每条 prompt 复制k份,组成 N-sample 的 GRPO group。 同组样本共享一个group_id,sample_id形如prompt:<gid>:sample:<j>。
所以一个 rollout 产出 prompts_per_rollout × sampling.samples_per_prompt 个样本。
数据源选择
| 数据源 | 何时用 | 由谁选择 |
|---|---|---|
MultimodalRLDataSource | 真实训练;读取配置的 data_path,shuffle,丢弃最后一个不完整 batch | recipe 设置 data_source._target_: unirl.data.data_source.MultimodalRLDataSource(默认) |
DefaultDataSource | smoke 检查;忽略 data_path,循环少量内置 prompt | recipe 把 data_source._target_ 指向 unirl.data.data_source.DefaultDataSource |
EVAL_DATA_PATH 指向独立的 eval prompt 文件(按确定性顺序加载);训练 batch
始终来自配置的 data_path。eval 路径的当前状态见 评测。
完整示例
# 1. 准备 prompts.json({"prompt": ...} 对象的 list)。
# 2. 让 DATA_PATH 指向它,并启动一个数据源读文件的 recipe。
DATA_PATH=/abs/path/prompts.json \
OUTPUT_DIR=/abs/path/outputs/run1 \
bash examples/run_experiment_single_node.sh diffusion/sd3_trainside启动 Ray 作业前先做 compose check:
DATA_PATH=/abs/path/prompts.json \
python -m unirl.train_diffusion --config-name=diffusion/sd3_trainside --cfg job --resolve