flyteplugins.wandb
Key features:

- Automatic W&B run initialization with the @wandb_init decorator
- Automatic W&B links in the Flyte UI pointing to runs and sweeps
- Parent/child task support with automatic run reuse
- W&B sweep creation and management with the @wandb_sweep decorator
- Configuration management with wandb_config() and wandb_sweep_config()
- Distributed training support (auto-detects PyTorch DDP/torchrun)
Basic usage:

Simple task with W&B logging:

```python
from flyteplugins.wandb import wandb_init, get_wandb_run


@wandb_init(project="my-project", entity="my-team")
@env.task
async def train_model(learning_rate: float) -> str:
    wandb_run = get_wandb_run()
    wandb_run.log({"loss": 0.5, "learning_rate": learning_rate})
    return wandb_run.id
```
Parent/child tasks with run reuse:

```python
@wandb_init  # Automatically reuses parent's run ID
@env.task
async def child_task(x: int) -> str:
    wandb_run = get_wandb_run()
    wandb_run.log({"child_metric": x * 2})
    return wandb_run.id


@wandb_init(project="my-project", entity="my-team")
@env.task
async def parent_task() -> str:
    wandb_run = get_wandb_run()
    wandb_run.log({"parent_metric": 100})
    # Child reuses parent's run by default (run_mode="auto")
    await child_task(5)
    return wandb_run.id
```
Configuration with context manager:

```python
from flyteplugins.wandb import wandb_config

r = flyte.with_runcontext(
    custom_context=wandb_config(
        project="my-project",
        entity="my-team",
        tags=["experiment-1"],
    )
).run(train_model, learning_rate=0.001)
```
Creating new runs for child tasks:

```python
@wandb_init(run_mode="new")  # Always creates a new run
@env.task
async def independent_child() -> str:
    wandb_run = get_wandb_run()
    wandb_run.log({"independent_metric": 42})
    return wandb_run.id
```
Running sweep agents in parallel:

```python
import asyncio

from flyteplugins.wandb import wandb_sweep, get_wandb_sweep_id, get_wandb_context


@wandb_init
async def objective():
    wandb_run = wandb.run
    config = wandb_run.config
    ...
    wandb_run.log({"loss": loss_value})


@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
    wandb.agent(sweep_id, function=objective, count=count, project=get_wandb_context().project)
    return agent_id


@wandb_sweep
@env.task
async def run_parallel_sweep(num_agents: int = 2, trials_per_agent: int = 5) -> str:
    sweep_id = get_wandb_sweep_id()
    # Launch agents in parallel
    agent_tasks = [
        sweep_agent(agent_id=i + 1, sweep_id=sweep_id, count=trials_per_agent)
        for i in range(num_agents)
    ]
    # Wait for all agents to complete
    await asyncio.gather(*agent_tasks)
    return sweep_id


# Run with 2 parallel agents
r = flyte.with_runcontext(
    custom_context={
        **wandb_config(project="my-project", entity="my-team"),
        **wandb_sweep_config(
            method="random",
            metric={"name": "loss", "goal": "minimize"},
            parameters={
                "learning_rate": {"min": 0.0001, "max": 0.1},
                "batch_size": {"values": [16, 32, 64]},
            },
        ),
    }
).run(run_parallel_sweep, num_agents=2, trials_per_agent=5)
```
Distributed Training Support:
The plugin auto-detects distributed training from environment variables (RANK, WORLD_SIZE, LOCAL_RANK, etc.) set by torchrun/torch.distributed.elastic.
The rank_scope parameter controls the scope of run creation:

- "global" (default): 1 run/group across all workers
- "worker": 1 run/group per worker

By default (run_mode="auto", rank_scope="global"):

- Single-node: only rank 0 logs (1 run)
- Multi-node: only global rank 0 logs (1 run)
```python
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import wandb_init, get_wandb_run

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "5Gi"), gpu="V100:4"),
    plugin_config=Elastic(nnodes=2, nproc_per_node=2),
)


@wandb_init
@torch_env.task
async def train_distributed():
    torch.distributed.init_process_group("nccl")
    # Only global rank 0 gets a W&B run; other ranks get None
    run = get_wandb_run()
    if run:
        run.log({"loss": loss})
    return run.id if run else "non-primary-rank"
```

Use rank_scope="worker" to get 1 run per worker:

```python
@wandb_init(rank_scope="worker")
@torch_env.task
async def train_distributed_per_worker():
    # Multi-node: local rank 0 of each worker gets a W&B run (1 run per worker)
    run = get_wandb_run()
    if run:
        run.log({"loss": loss})
    return run.id if run else "non-primary-rank"
```

Use run_mode="shared" for all ranks to log to shared run(s):

```python
@wandb_init(run_mode="shared")  # rank_scope="global": 1 shared run across all ranks
@torch_env.task
async def train_distributed_shared():
    # All ranks log to the same W&B run (with x_label to identify each rank)
    run = get_wandb_run()
    run.log({"rank_metric": value})
    return run.id


@wandb_init(run_mode="shared", rank_scope="worker")  # 1 shared run per worker
@torch_env.task
async def train_distributed_shared_per_worker():
    run = get_wandb_run()
    run.log({"rank_metric": value})
    return run.id
```

Use run_mode="new" for each rank to have its own W&B run:

```python
@wandb_init(run_mode="new")  # rank_scope="global": all runs in 1 group
@torch_env.task
async def train_distributed_separate_runs():
    # Each rank gets its own W&B run (grouped in the W&B UI)
    # Run IDs: {base}-rank-{global_rank}
    run = get_wandb_run()
    run.log({"rank_metric": value})
    return run.id


@wandb_init(run_mode="new", rank_scope="worker")  # runs grouped per worker
@torch_env.task
async def train_distributed_separate_runs_per_worker():
    run = get_wandb_run()
    run.log({"rank_metric": value})
    return run.id
```
Decorator order: @wandb_init or @wandb_sweep must be the outermost decorator:

```python
@wandb_init
@env.task
async def my_task():
    ...
```

Directory
Classes
| Class | Description |
|---|---|
| Wandb | Generates a Weights & Biases run link. |
| WandbSweep | Generates a Weights & Biases sweep link. |
Methods
| Method | Description |
|---|---|
| download_wandb_run_dir() | Download wandb run data from wandb cloud. |
| download_wandb_run_logs() | Traced function to download wandb run logs after task completion. |
| download_wandb_sweep_dirs() | Download all run data for a wandb sweep. |
| download_wandb_sweep_logs() | Traced function to download wandb sweep logs after task completion. |
| get_distributed_info() | Get distributed training info if running in a distributed context. |
| get_wandb_context() | Get wandb config from the current Flyte context. |
| get_wandb_run() | Get the current wandb run if within a @wandb_init decorated task or trace. |
| get_wandb_run_dir() | Get the local directory path for the current wandb run. |
| get_wandb_sweep_context() | Get wandb sweep config from the current Flyte context. |
| get_wandb_sweep_id() | Get the current wandb sweep_id if within a @wandb_sweep decorated task. |
| wandb_config() | Create wandb configuration. |
| wandb_init() | Decorator to automatically initialize wandb for Flyte tasks and wandb sweep objectives. |
| wandb_sweep() | Decorator to create a wandb sweep and make sweep_id available. |
| wandb_sweep_config() | Create wandb sweep configuration for hyperparameter optimization. |
Methods
download_wandb_run_dir()
```python
def download_wandb_run_dir(
    run_id: typing.Optional[str],
    path: typing.Optional[str],
    include_history: bool,
) -> str
```

Download wandb run data from wandb cloud.
Downloads all run files and optionally exports metrics history to JSON. This enables access to wandb data from any task or after workflow completion.
Downloaded contents:
- summary.json - final summary metrics (always exported)
- metrics_history.json - step-by-step metrics (if include_history=True)
- Plus any files synced by wandb (requirements.txt, wandb_metadata.json, etc.)
| Parameter | Type | Description |
|---|---|---|
| run_id | typing.Optional[str] | The wandb run ID to download. If None, uses the current run's ID from context (useful for shared runs across tasks). |
| path | typing.Optional[str] | Local directory to download files to. If None, downloads to /tmp/wandb_runs/{run_id}. |
| include_history | bool | If True, exports the step-by-step metrics history to metrics_history.json. Defaults to True. |
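A minimal usage sketch. It assumes the Optional parameters default to None (and include_history to True), as the parameter table above implies:

```python
from flyteplugins.wandb import download_wandb_run_dir

# Inside any task (or after the workflow): run_id=None falls back to the
# current run's ID from context, per the parameter table above.
local_dir = download_wandb_run_dir(run_id=None, path=None, include_history=True)
print(local_dir)  # e.g. /tmp/wandb_runs/{run_id}, containing summary.json etc.
```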
download_wandb_run_logs()
```python
def download_wandb_run_logs(
    run_id: str,
) -> flyte.io._dir.Dir
```

Traced function to download wandb run logs after task completion.

This function is called automatically when download_logs=True is set in @wandb_init or wandb_config(). The downloaded files appear as a trace output in the Flyte UI.
| Parameter | Type | Description |
|---|---|---|
| run_id | str | The wandb run ID to download. |
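You normally don't call this directly; setting download_logs=True triggers it automatically, for example:

```python
from flyteplugins.wandb import wandb_init, get_wandb_run


@wandb_init(download_logs=True)  # run files are attached as a trace output
@env.task
async def train_with_logs() -> str:
    run = get_wandb_run()
    run.log({"loss": 0.1})
    return run.id
```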
download_wandb_sweep_dirs()
```python
def download_wandb_sweep_dirs(
    sweep_id: typing.Optional[str],
    base_path: typing.Optional[str],
    include_history: bool,
) -> list[str]
```

Download all run data for a wandb sweep.
Queries the wandb API for all runs in the sweep and downloads their files and metrics history. This is useful for collecting results from all sweep trials after completion.
| Parameter | Type | Description |
|---|---|---|
| sweep_id | typing.Optional[str] | The wandb sweep ID. If None, uses the current sweep's ID from context (set by the @wandb_sweep decorator). |
| base_path | typing.Optional[str] | Base directory to download files to. Each run's files will be in a subdirectory named by run_id. If None, uses /tmp/wandb_runs/. |
| include_history | bool | If True, exports the step-by-step metrics history to metrics_history.json for each run. Defaults to True. |
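A usage sketch, assuming the Optional parameters default to None as the table above implies:

```python
from flyteplugins.wandb import download_wandb_sweep_dirs

# After the sweep agents finish: sweep_id=None falls back to the current
# sweep's ID from context (set by @wandb_sweep).
run_dirs = download_wandb_sweep_dirs(sweep_id=None, base_path=None, include_history=True)
for run_dir in run_dirs:
    print(run_dir)  # one subdirectory per run under /tmp/wandb_runs/
```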
download_wandb_sweep_logs()
```python
def download_wandb_sweep_logs(
    sweep_id: str,
) -> flyte.io._dir.Dir
```

Traced function to download wandb sweep logs after task completion.

This function is called automatically when download_logs=True is set in @wandb_sweep or wandb_sweep_config(). The downloaded files appear as a trace output in the Flyte UI.
| Parameter | Type | Description |
|---|---|---|
| sweep_id | str | The wandb sweep ID to download. |
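As with run logs, this is triggered automatically rather than called directly, for example:

```python
from flyteplugins.wandb import wandb_sweep, get_wandb_sweep_id


@wandb_sweep(download_logs=True)  # all sweep run files are attached as a trace output
@env.task
async def run_sweep() -> str:
    sweep_id = get_wandb_sweep_id()
    ...
    return sweep_id
```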
get_distributed_info()
```python
def get_distributed_info()
```

Get distributed training info if running in a distributed context.

This function auto-detects distributed training from environment variables set by torchrun/torch.distributed.elastic.

Returns: dict | None: Dictionary with distributed info, or None if not distributed.

- rank: Global rank (0 to world_size - 1)
- local_rank: Rank within the node (0 to local_world_size - 1)
- world_size: Total number of processes
- local_world_size: Processes per node
- worker_index: Node/worker index (0 to num_workers - 1)
- num_workers: Total number of nodes/workers
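A quick illustration using the keys listed above:

```python
from flyteplugins.wandb import get_distributed_info

info = get_distributed_info()
if info is None:
    print("Not running in a distributed context")
elif info["rank"] == 0:
    print(f"Global rank 0 of {info['world_size']} processes, "
          f"worker {info['worker_index']} of {info['num_workers']}")
```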
get_wandb_context()
```python
def get_wandb_context()
```

Get wandb config from the current Flyte context.
get_wandb_run()
```python
def get_wandb_run()
```

Get the current wandb run if within a @wandb_init decorated task or trace.
The run is initialized when the @wandb_init context manager is entered.
Returns None if not within a wandb_init context.
Returns:
wandb.sdk.wandb_run.Run | None: The current wandb run object or None.
get_wandb_run_dir()
```python
def get_wandb_run_dir()
```

Get the local directory path for the current wandb run.
Use this for accessing files written by the current task without any
network calls. For accessing files from other tasks (or after a task
completes), use download_wandb_run_dir() instead.
Returns:
Local path to wandb run directory (wandb.run.dir) or None if no
active run.
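A short sketch; it assumes files written under the run directory are picked up by wandb's normal sync of wandb.run.dir:

```python
import os

from flyteplugins.wandb import get_wandb_run_dir

run_dir = get_wandb_run_dir()
if run_dir:
    # Write an artifact alongside the run's other files (no network calls).
    with open(os.path.join(run_dir, "notes.txt"), "w") as f:
        f.write("checkpoint saved at step 1000")
```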
get_wandb_sweep_context()
```python
def get_wandb_sweep_context()
```

Get wandb sweep config from the current Flyte context.
get_wandb_sweep_id()
```python
def get_wandb_sweep_id()
```

Get the current wandb sweep_id if within a @wandb_sweep decorated task.
Returns None if not within a wandb_sweep context.
Returns:
str | None: The sweep ID or None.
wandb_config()
```python
def wandb_config(
    project: typing.Optional[str],
    entity: typing.Optional[str],
    id: typing.Optional[str],
    name: typing.Optional[str],
    tags: typing.Optional[list[str]],
    config: typing.Optional[dict[str, typing.Any]],
    mode: typing.Optional[str],
    group: typing.Optional[str],
    run_mode: typing.Literal['auto', 'new', 'shared'],
    rank_scope: typing.Literal['global', 'worker'],
    download_logs: bool,
    **kwargs,
) -> flyteplugins.wandb._context._WandBConfig
```

Create wandb configuration.
This function works in two contexts:
- With flyte.with_runcontext() - sets global wandb config
- As a context manager - overrides config for specific tasks (see the sketch below)
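The with_runcontext() form is shown under Basic usage above. A sketch of the context-manager form, assuming the overrides apply to tasks invoked inside the with block (parent and train_model here refer to the earlier examples):

```python
from flyteplugins.wandb import wandb_config, wandb_init


@wandb_init
@env.task
async def parent() -> None:
    # train_model picks up the overridden project and tags.
    with wandb_config(project="override-project", tags=["ablation"]):
        await train_model(learning_rate=0.01)
```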
| Parameter | Type | Description |
|---|---|---|
| project | typing.Optional[str] | W&B project name |
| entity | typing.Optional[str] | W&B entity (team or username) |
| id | typing.Optional[str] | Unique run ID (auto-generated if not provided) |
| name | typing.Optional[str] | Human-readable run name |
| tags | typing.Optional[list[str]] | List of tags for organizing runs |
| config | typing.Optional[dict[str, typing.Any]] | Dictionary of hyperparameters |
| mode | typing.Optional[str] | "online", "offline", or "disabled" |
| group | typing.Optional[str] | Group name for related runs |
| run_mode | typing.Literal['auto', 'new', 'shared'] | Controls whether tasks create new W&B runs or share existing ones. "auto" (default): creates a new run if no parent run exists, otherwise shares the parent's run. "new": always creates a new wandb run with a unique ID. "shared": always shares the parent's run ID. In a single-node distributed training context: "auto" - only rank 0 logs; "shared" - all ranks log to a single shared W&B run; "new" - each rank gets its own W&B run (grouped in the W&B UI). Multi-node behavior depends on rank_scope. |
| rank_scope | typing.Literal['global', 'worker'] | Controls which ranks log in distributed training. With run_mode="auto": "global" (default) - only global rank 0 logs (1 run total); "worker" - local rank 0 of each worker logs (1 run per worker). With run_mode="shared": "global" - all ranks log to a single shared W&B run; "worker" - ranks on each worker log to a shared run (1 run per worker). With run_mode="new": "global" - each rank gets its own W&B run, all in one group; "worker" - each rank gets its own W&B run, grouped per worker. |
| download_logs | bool | If True, downloads wandb run files after the task completes and shows them as a trace output in the Flyte UI |
| kwargs | **kwargs | |
wandb_init()
```python
def wandb_init(
    _func: typing.Optional[~F],
    run_mode: typing.Optional[typing.Literal['auto', 'new', 'shared']],
    rank_scope: typing.Optional[typing.Literal['global', 'worker']],
    download_logs: typing.Optional[bool],
    project: typing.Optional[str],
    entity: typing.Optional[str],
    **kwargs,
) -> ~F
```

Decorator to automatically initialize wandb for Flyte tasks and wandb sweep objectives.
| Parameter | Type | Description |
|---|---|---|
| _func | typing.Optional[~F] | |
| run_mode | typing.Optional[typing.Literal['auto', 'new', 'shared']] | "auto", "new", or "shared". See wandb_config() for details. |
| rank_scope | typing.Optional[typing.Literal['global', 'worker']] | Flyte-specific rank scope, "global" or "worker". Controls which ranks log in distributed training. With run_mode="auto": "global" (default) - only global rank 0 logs (1 run total); "worker" - local rank 0 of each worker logs (1 run per worker). With run_mode="shared": "global" - all ranks log to a single shared W&B run; "worker" - ranks on each worker log to a shared run (1 run per worker). With run_mode="new": "global" - each rank gets its own W&B run, all in one group; "worker" - each rank gets its own W&B run, grouped per worker. |
| download_logs | typing.Optional[bool] | If True, downloads wandb run files after the task completes and shows them as a trace output in the Flyte UI. If None, uses the value from the wandb_config() context if set. |
| project | typing.Optional[str] | W&B project name (overrides context config if provided) |
| entity | typing.Optional[str] | W&B entity/team name (overrides context config if provided) |
| kwargs | **kwargs | |
wandb_sweep()
```python
def wandb_sweep(
    _func: typing.Optional[~F],
    project: typing.Optional[str],
    entity: typing.Optional[str],
    download_logs: typing.Optional[bool],
    **kwargs,
) -> ~F
```

Decorator to create a wandb sweep and make sweep_id available.
This decorator:

- Creates a wandb sweep using config from context
- Makes sweep_id available via get_wandb_sweep_id()
- Automatically adds a W&B sweep link to the task
- Optionally downloads all sweep run logs as a trace output (if download_logs=True)
| Parameter | Type | Description |
|---|---|---|
| _func | typing.Optional[~F] | |
| project | typing.Optional[str] | W&B project name (overrides context config if provided) |
| entity | typing.Optional[str] | W&B entity/team name (overrides context config if provided) |
| download_logs | typing.Optional[bool] | If True, downloads all sweep run files after the task completes and shows them as a trace output in the Flyte UI. If None, uses the value from the wandb_sweep_config() context if set. |
| kwargs | **kwargs | |
wandb_sweep_config()
```python
def wandb_sweep_config(
    method: typing.Optional[str],
    metric: typing.Optional[dict[str, typing.Any]],
    parameters: typing.Optional[dict[str, typing.Any]],
    project: typing.Optional[str],
    entity: typing.Optional[str],
    prior_runs: typing.Optional[list[str]],
    name: typing.Optional[str],
    download_logs: bool,
    **kwargs,
) -> flyteplugins.wandb._context._WandBSweepConfig
```

Create wandb sweep configuration for hyperparameter optimization.
| Parameter | Type | Description |
|---|---|---|
| method | typing.Optional[str] | Sweep method (e.g., "random", "grid", "bayes") |
| metric | typing.Optional[dict[str, typing.Any]] | Metric to optimize (e.g., {"name": "loss", "goal": "minimize"}) |
| parameters | typing.Optional[dict[str, typing.Any]] | Parameter definitions for the sweep |
| project | typing.Optional[str] | W&B project for the sweep |
| entity | typing.Optional[str] | W&B entity for the sweep |
| prior_runs | typing.Optional[list[str]] | List of prior run IDs to include in the sweep analysis |
| name | typing.Optional[str] | Sweep name (auto-generated as {run_name}-{action_name} if not provided) |
| download_logs | bool | If True, downloads all sweep run files after the task completes and shows them as a trace output in the Flyte UI |
| kwargs | **kwargs | |