Unit testing is essential for ensuring your Flyte tasks work correctly. Flyte 2.0 provides flexible testing approaches that allow you to test both your business logic and Flyte-specific features like type transformations and caching.
Understanding Task Invocation
When working with functions decorated with @env.task, there are two ways to invoke them, each with different behavior:
Direct Function Invocation
When you call a task directly like a regular Python function:
```python
result = my_task(x=10, y=20)
```
Flyte features are NOT invoked, including:
- Type transformations and serialization
- Caching
- Data validation
This behaves exactly like calling a regular Python function, making it ideal for testing your business logic.
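Because a direct call is plain Python, a task's type hints are not enforced. The sketch below shows this with an undecorated function that has the same body as the add task used later in this guide. It has no Flyte import, so it runs anywhere. An argument of the wrong type is accepted silently. The validation and type checking described in the next section only apply when the call goes through flyte.run().

```python
# A plain function with the same body as the add task used later in this
# guide. Calling a task directly behaves like calling this function.
def add(a: int, b: int) -> int:
    return a + b

# Python does not check annotations at runtime, so strings are accepted and
# "+" concatenates them instead of adding numbers.
result = add(a="3", b="5")
print(repr(result))  # '35'
```

A direct-invocation test with well-typed inputs would pass here. A test that depends on type validation needs flyte.run() instead.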
Using flyte.run()
When you invoke a task using flyte.run():
```python
run = flyte.run(my_task, x=10, y=20)
result = run.outputs()
```
Flyte features ARE invoked, including:
- Type transformations and serialization
- Data validation
- Type checking (raises flyte.errors exceptions if types are not supported or restricted)
This allows you to test Flyte-specific behavior like serialization and caching.
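A serialization step can change what a function actually receives, so some problems only show up under flyte.run(). The sketch below is a Flyte-independent analogy. It uses the standard-library json module, not Flyte's type engine, and roundtrip is a hypothetical helper. A tuple goes in, but a list comes back.

```python
import json
from typing import Any


def roundtrip(value: Any) -> Any:
    """Hypothetical helper: serialize to JSON text and parse it back."""
    return json.loads(json.dumps(value))


original = ("a", "b")
restored = roundtrip(original)

print(type(original).__name__, type(restored).__name__)  # tuple list
print(restored == original)        # False: a list never equals a tuple
print(restored == list(original))  # True: same items, different type
```

Flyte performs its own type transformations, so the json behavior above is only an analogy. To learn how a given type actually behaves, run the task through flyte.run().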
Testing Business Logic
For most unit tests, you want to verify your business logic works correctly. Use direct function invocation for this:
```python
import flyte

env = flyte.TaskEnvironment("my_env")

@env.task
def add(a: int, b: int) -> int:
    return a + b

def test_add():
    result = add(a=3, b=5)
    assert result == 8
```
Testing Async Tasks
Async tasks work the same way with direct invocation:
```python
import pytest

@env.task
async def subtract(a: int, b: int) -> int:
    return a - b

# @pytest.mark.asyncio is provided by the pytest-asyncio plugin
@pytest.mark.asyncio
async def test_subtract():
    result = await subtract(a=10, b=4)
    assert result == 6
```
Testing Nested Tasks
When tasks call other tasks, direct invocation continues to work without any Flyte overhead:
```python
@env.task
def nested(a: int, b: int) -> int:
    return add(a, b)  # Calls the add task directly

def test_nested():
    result = nested(3, 5)
    assert result == 8
```
Testing Type Transformations and Serialization
When you need to test how Flyte handles data types, serialization, or caching, use flyte.run():
```python
@pytest.mark.asyncio
async def test_add_with_flyte_run():
    run = flyte.run(add, 3, 5)
    assert run.outputs() == 8
```
Testing Type Restrictions
Some types may not be supported or may be restricted. Use flyte.run() to test that these restrictions are enforced:
```python
from typing import Tuple

import flyte.errors

@env.task
def not_supported_types(x: Tuple[str, str]) -> str:
    return x[0]

@pytest.mark.asyncio
async def test_not_supported_types():
    # Direct invocation works fine
    result = not_supported_types(x=("a", "b"))
    assert result == "a"

    # flyte.run enforces type restrictions
    with pytest.raises(flyte.errors.RestrictedTypeError):
        flyte.run(not_supported_types, x=("a", "b"))
```
Testing Nested Tasks with Serialization
You can also test nested task execution with Flyte’s full machinery:
```python
@pytest.mark.asyncio
async def test_nested_with_run():
    run = flyte.run(nested, 3, 5)
    assert run.outputs() == 8
```
Testing Traced Functions
Functions decorated with @flyte.trace can be tested similarly to tasks:
```python
@flyte.trace
async def traced_multiply(a: int, b: int) -> int:
    return a * b

@pytest.mark.asyncio
async def test_traced_multiply():
    result = await traced_multiply(a=6, b=7)
    assert result == 42
```
Best Practices
- Test logic with direct invocation: For most unit tests, call tasks directly to test your business logic without Flyte overhead.
- Test serialization with flyte.run(): Use flyte.run() when you need to verify:
  - Type transformations work correctly
  - Data serialization/deserialization
  - Caching behavior
  - Type restrictions are enforced
- Use standard testing frameworks: Flyte tasks work with pytest, unittest, and other Python testing frameworks.
- Test async tasks properly: Use @pytest.mark.asyncio (from the pytest-asyncio plugin) for async tasks and await their results.
- Mock external dependencies: Use standard Python mocking techniques, such as unittest.mock, for external services, databases, and similar dependencies.
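The practices on testing frameworks, async tests, and mocking can be combined using only the standard library (Python 3.8+). In the sketch below, convert is a plain async function standing in for an async task body. RateClient is a hypothetical client for an external service, which the test replaces with an AsyncMock. unittest.IsolatedAsyncioTestCase runs the coroutine test without needing pytest-asyncio. With a real task, you would call the decorated function directly in the same way.

```python
import unittest
from unittest import mock


class RateClient:
    """Hypothetical client for an external pricing service."""

    async def get_rate(self, currency: str) -> float:
        # A real client would make a network call here.
        raise RuntimeError("network access is not available in unit tests")


async def convert(amount: float, currency: str) -> float:
    """Stand-in for the body of an async task that uses RateClient."""
    rate = await RateClient().get_rate(currency)
    return round(amount * rate, 2)


class ConvertTest(unittest.IsolatedAsyncioTestCase):
    async def test_convert_uses_mocked_rate(self) -> None:
        # Patch the method on the class so every instance returns a fixed rate.
        fake = mock.AsyncMock(return_value=1.5)
        with mock.patch.object(RateClient, "get_rate", new=fake):
            result = await convert(10.0, "EUR")
        self.assertEqual(result, 15.0)
        # The mock is a plain class attribute, so it is called without self.
        fake.assert_awaited_once_with("EUR")
```

Save this as a test module and run it with python -m unittest. You can also run it with pytest, which collects unittest.TestCase subclasses as well.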
Quick Reference
| Test Scenario | Method | Example |
|---|---|---|
| Business logic (sync) | Direct call | result = task(x=10) |
| Business logic (async) | Direct await | result = await task(x=10) |
| Type transformations | flyte.run() | r = flyte.run(task, x=10) |
| Data serialization | flyte.run() | r = flyte.run(task, x=10) |
| Caching behavior | flyte.run() | r = flyte.run(task, x=10) |
| Type restrictions | flyte.run() + pytest.raises | pytest.raises(flyte.errors.RestrictedTypeError) |
Example Test Suite
Here’s a complete example showing different testing approaches:
```python
import pytest

import flyte
import flyte.errors

env = flyte.TaskEnvironment("test_env")

@env.task
def add(a: int, b: int) -> int:
    return a + b

@env.task
async def subtract(a: int, b: int) -> int:
    return a - b

# Test business logic directly
def test_add_logic():
    result = add(a=3, b=5)
    assert result == 8

@pytest.mark.asyncio
async def test_subtract_logic():
    result = await subtract(a=10, b=4)
    assert result == 6

# Test with Flyte serialization
@pytest.mark.asyncio
async def test_add_serialization():
    run = flyte.run(add, 3, 5)
    assert run.outputs() == 8

@pytest.mark.asyncio
async def test_subtract_serialization():
    run = flyte.run(subtract, a=10, b=4)
    assert run.outputs() == 6
```
Future Improvements
The Flyte SDK team is actively working on improvements for advanced unit testing scenarios, particularly around initialization and setup for complex test cases. Additional utilities and patterns may be introduced in future releases to make unit testing even more streamlined.