Inference Actions
Use inference actions to build model inference workflows. Use BaseInferenceAction for batch inference and BaseDeploymentAction for REST API serving via Ray Serve.
Overview
Synapse SDK provides two base classes for inference workflows:
| Base Class | Purpose | Use Case |
|---|---|---|
BaseInferenceAction | Batch inference | Processing datasets, offline predictions |
BaseDeploymentAction | REST API serving | Real-time inference endpoints via Ray Serve |
Both classes support two execution modes:
- Simple Mode: Override
execute()directly for straightforward workflows - Step-Based Mode: Use
setup_steps()to register workflow steps for complex pipelines
BaseInferenceAction
The base class for inference actions. Provides helper methods for model loading and inference workflows.
class BaseInferenceAction(BaseAction[P]):
category = PluginCategory.NEURAL_NET
progress = InferenceProgressCategories()
Progress Categories
Track inference progress with these standard categories:
| Category | Value | Description |
|---|---|---|
MODEL_LOAD | 'model_load' | Model loading and initialization |
INFERENCE | 'inference' | Running inference on inputs |
POSTPROCESS | 'postprocess' | Post-processing results |
self.set_progress(1, 3, self.progress.MODEL_LOAD)
self.set_progress(2, 3, self.progress.INFERENCE)
self.set_progress(3, 3, self.progress.POSTPROCESS)
Helper Methods
get_model
Retrieve model metadata by ID.
def get_model(self, model_id: int) -> dict[str, Any]
| Parameter | Type | Required | Description |
|---|---|---|---|
model_id | int | Yes | Model identifier |
Returns: Model metadata dictionary including file URL.
model = self.get_model(123)
print(model['name'], model['file'])
download_model
Download and extract model artifacts.
def download_model(
self,
model_id: int,
output_dir: str | Path | None = None,
) -> Path
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
model_id | int | Yes | - | Model identifier |
output_dir | str | Path | None | No | None | Directory to extract model to. Uses tempdir if None |
Returns: Path to extracted model directory.
model_path = self.download_model(123)
# model_path contains extracted model artifacts
load_model
Load model for inference. Downloads artifacts and returns model info with local path.
def load_model(self, model_id: int) -> dict[str, Any]
| Parameter | Type | Required | Description |
|---|---|---|---|
model_id | int | Yes | Model identifier |
Returns: Model metadata dict with 'path' key for local artifacts.
model_info = self.load_model(123)
model_path = model_info['path']
# Load your model framework here:
# model = torch.load(model_path / 'model.pt')
infer
Run inference on inputs. Override this method to implement your inference logic.
def infer(
self,
model: Any,
inputs: list[dict[str, Any]],
) -> list[dict[str, Any]]
| Parameter | Type | Required | Description |
|---|---|---|---|
model | Any | Yes | Loaded model (framework-specific) |
inputs | list[dict[str, Any]] | Yes | List of input dictionaries |
Returns: List of result dictionaries.
def infer(self, model, inputs):
results = []
for inp in inputs:
prediction = model.predict(inp['image'])
results.append({'prediction': prediction})
return results
Execution Modes
Simple Mode
Override execute() directly for simple workflows:
from synapse_sdk.plugins.actions.inference import BaseInferenceAction
from pydantic import BaseModel
class InferenceParams(BaseModel):
model_id: int
inputs: list[dict]
class MyInferenceAction(BaseInferenceAction[InferenceParams]):
action_name = 'inference'
def execute(self) -> dict[str, Any]:
# Load model
model_info = self.load_model(self.params.model_id)
self.set_progress(1, 3, self.progress.MODEL_LOAD)
# Run inference
results = self.infer(model_info, self.params.inputs)
self.set_progress(2, 3, self.progress.INFERENCE)
# Post-process
processed = self._postprocess(results)
self.set_progress(3, 3, self.progress.POSTPROCESS)
return {'results': processed}
def infer(self, model, inputs):
import torch
model_obj = torch.load(model['path'] + '/model.pt')
results = []
for inp in inputs:
pred = model_obj(inp['tensor'])
results.append({'prediction': pred.tolist()})
return results
Step-Based Mode
Use setup_steps() to register workflow steps for complex pipelines:
from synapse_sdk.plugins.actions.inference import (
BaseInferenceAction,
InferenceContext,
)
from synapse_sdk.plugins.steps import BaseStep, StepResult, StepRegistry
class LoadModelStep(BaseStep[InferenceContext]):
@property
def name(self) -> str:
return 'load_model'
@property
def progress_weight(self) -> float:
return 0.3
def execute(self, context: InferenceContext) -> StepResult:
# Load model using context
import torch
model_path = context.model_path
context.model = torch.load(f'{model_path}/model.pt')
return StepResult(success=True)
class InferenceStep(BaseStep[InferenceContext]):
@property
def name(self) -> str:
return 'inference'
@property
def progress_weight(self) -> float:
return 0.7
def execute(self, context: InferenceContext) -> StepResult:
for request in context.requests:
prediction = context.model(request['input'])
context.results.append({'prediction': prediction})
context.processed_count += 1
return StepResult(success=True)
class MyInferenceAction(BaseInferenceAction[InferenceParams]):
action_name = 'inference'
def setup_steps(self, registry: StepRegistry[InferenceContext]) -> None:
registry.register(LoadModelStep())
registry.register(InferenceStep())
InferenceContext
Context for inference action step-based workflows. Extends BaseStepContext with inference-specific state.
@dataclass
class InferenceContext(BaseStepContext):
params: dict[str, Any] = field(default_factory=dict)
model_id: int | None = None
model: dict[str, Any] | None = None
model_path: str | None = None
requests: list[dict[str, Any]] = field(default_factory=list)
results: list[dict[str, Any]] = field(default_factory=list)
batch_size: int = 1
processed_count: int = 0
| Attribute | Type | Description |
|---|---|---|
params | dict[str, Any] | Action parameters |
model_id | int | None | ID of the model being used |
model | dict[str, Any] | None | Loaded model information from backend |
model_path | str | None | Local path to downloaded model |
requests | list[dict[str, Any]] | Input requests to process |
results | list[dict[str, Any]] | Inference results |
batch_size | int | Batch size for processing |
processed_count | int | Number of processed items |
Example: Batch Inference
Complete example of a batch inference action with PyTorch:
from pathlib import Path
from typing import Any
import torch
from pydantic import BaseModel
from synapse_sdk.plugins.actions.inference import BaseInferenceAction
class BatchInferenceParams(BaseModel):
model_id: int
inputs: list[dict[str, Any]]
batch_size: int = 32
class BatchInferenceAction(BaseInferenceAction[BatchInferenceParams]):
"""Batch inference action for PyTorch models."""
action_name = 'batch_inference'
def execute(self) -> dict[str, Any]:
# Step 1: Load model
model_info = self.load_model(self.params.model_id)
model_path = Path(model_info['path'])
model = torch.load(model_path / 'model.pt')
model.eval()
self.set_progress(1, 3, self.progress.MODEL_LOAD)
# Step 2: Run inference in batches
results = []
inputs = self.params.inputs
batch_size = self.params.batch_size
for i in range(0, len(inputs), batch_size):
batch = inputs[i : i + batch_size]
batch_tensors = torch.stack([torch.tensor(inp['data']) for inp in batch])
with torch.no_grad():
predictions = model(batch_tensors)
for pred in predictions:
results.append({'prediction': pred.tolist()})
self.set_progress(2, 3, self.progress.INFERENCE)
# Step 3: Post-process
self.set_progress(3, 3, self.progress.POSTPROCESS)
return {
'results': results,
'processed_count': len(results),
}
def infer(self, model, inputs):
# Optional: Override for custom inference logic
pass
Deployment Actions
Use deployment actions to serve REST APIs via Ray Serve. Use BaseDeploymentAction to deploy inference endpoints.
BaseDeploymentAction
Base class for Ray Serve deployment actions. Handles Ray initialization, deployment creation, and backend registration.
class BaseDeploymentAction(BaseAction[P]):
progress = DeploymentProgressCategories()
entrypoint: type | None = None # Set to your serve deployment class
DeploymentProgressCategories
| Category | Value | Description |
|---|---|---|
INITIALIZE | 'initialize' | Ray cluster initialization |
DEPLOY | 'deploy' | Deploying to Ray Serve |
REGISTER | 'register' | Registering with backend |
Deployment Methods
ray_init
Initialize Ray cluster connection.
def ray_init(self, **kwargs: Any) -> None
deploy
Deploy the inference endpoint to Ray Serve.
def deploy(self) -> None
Internally calls a pre-flight cluster capacity gate (_check_serve_capacity) immediately before serve.run(...). See Capacity Gate below.
register_serve_application
Register the serve application with the backend.
def register_serve_application(self) -> int | None
Returns: Serve application ID if created, None otherwise.
Configuration Methods
Override these methods to customize deployment:
| Method | Default | Description |
|---|---|---|
get_serve_app_name() | SYNAPSE_PLUGIN_RELEASE_CODE env var | Serve application name |
get_route_prefix() | SYNAPSE_PLUGIN_RELEASE_CHECKSUM env var | URL route prefix |
get_ray_actor_options() | Extract from params | Ray actor options (num_cpus, num_gpus, memory) |
get_runtime_env() | {} | Ray runtime environment |
BaseServeDeployment
Base class for Ray Serve inference deployments. Inherits from BaseAction and provides model loading with multiplexing support.
class BaseServeDeployment(BaseAction):
def __init__(self, backend_url: str) -> None:
self.backend_url = backend_url
self._model_cache: dict[str, Any] = {}
Abstract methods to implement:
| Method | Description |
|---|---|
async _get_model(model_info: dict) -> Any | Load model from extracted artifacts |
async infer(*args, **kwargs) -> Any | Run inference on inputs |
Class attributes:
| Attribute | Description |
|---|---|
action_name | Name for action discovery (e.g., 'inference') |
app | FastAPI app instance (decorators applied automatically by deploy()) |
infer_remote
Call the deployed serve endpoint for inference. Used by the entrypoint when an inference action is executed.
@classmethod
def infer_remote(cls, params: dict[str, Any], ctx: Any) -> Any
Params format:
| Key | Type | Required | Description |
|---|---|---|---|
model | int | str | No | Model ID for multiplexing |
method | str | No | HTTP method (default: 'post') |
json | dict | Yes | Request body sent to the serve endpoint |
# Inference params payload
params = {
'model': 34,
'method': 'post',
'json': {
'image_path': 'https://example.com/image.jpg',
'threshold': 0.5,
},
}
Batch Inference (@serve.batch)
Enable GPU-efficient batch inference by configuring batch in serve_options. Ray Serve automatically collects concurrent requests and passes them to infer_batch() as a list.
Configuration
Add the batch key to your config.yaml:
actions:
deployment:
entrypoint: plugin.inference.InferenceDeployment
method: job
serve_options:
num_replicas: 1
batch:
max_batch_size: 16 # Maximum requests to batch together
batch_wait_timeout_s: 0.1 # Seconds to wait for batch to fill
| Key | Type | Default | Description |
|---|---|---|---|
max_batch_size | int | 8 | Maximum number of requests per batch |
batch_wait_timeout_s | float | 0.1 | How long to wait for more requests before processing |
When batch is not configured, the deployment processes requests individually (backward compatible).
How It Works
Client sends N concurrent requests
↓
Ray Serve collects them (up to max_batch_size within batch_wait_timeout_s)
↓
Calls infer_batch([req1, req2, ..., reqN])
↓
GPU processes all inputs in a single forward pass
↓
Results distributed back to individual callers
Default Behavior
BaseServeDeployment.infer_batch() has a default implementation that falls back to calling infer() for each request individually. This means existing plugins work without modification — they just don't get GPU batching benefits.
# Default fallback (no override needed)
async def infer_batch(self, requests: list[Any]) -> list[Any]:
return [await self.infer(req) for req in requests]
GPU-Optimized Batch Inference
Override infer_batch() to process multiple inputs in a single GPU forward pass:
class YOLOServeDeployment(BaseServeDeployment):
action_name = 'inference'
app = app
async def _get_model(self, model_info: dict[str, Any]) -> Any:
from ultralytics import YOLO
return YOLO(model_info['path'] / 'best.pt')
@app.post('/')
async def infer(self, data: ImageData) -> dict[str, Any]:
"""Single request — calls infer_batch internally via @serve.batch."""
return await self.infer_batch(data)
async def infer_batch(self, requests: list[ImageData]) -> list[dict[str, Any]]:
"""Batch inference — GPU processes all images at once."""
model = await self.get_model()
images = [req.cv_image for req in requests]
results = model(images, conf=requests[0].threshold)
return [serialize(r) for r in results]
The @serve.batch decorator is applied automatically during deployment based on your config.yaml — do not apply it yourself.
Recommended max_batch_size by Model
| Model Type | Recommended max_batch_size | Notes |
|---|---|---|
| ResNet / EfficientNet (classification) | 32–64 | Lightweight, high throughput |
| YOLO (detection) | 8–16 | Moderate GPU memory usage |
| SAM / segmentation | 2–4 | Heavy GPU memory usage |
DeploymentContext
Context for deployment action step-based workflows.
@dataclass
class DeploymentContext(BaseStepContext):
params: dict[str, Any] = field(default_factory=dict)
model_id: int | None = None
model: dict[str, Any] | None = None
model_path: str | None = None
serve_app_name: str | None = None
serve_app_id: int | None = None
route_prefix: str | None = None
ray_actor_options: dict[str, Any] = field(default_factory=dict)
deployed: bool = False
| Attribute | Type | Description |
|---|---|---|
serve_app_name | str | None | Ray Serve application name |
serve_app_id | int | None | ID of created serve application |
route_prefix | str | None | URL route prefix for deployment |
ray_actor_options | dict[str, Any] | Ray actor configuration |
deployed | bool | Whether deployment succeeded |
create_serve_multiplexed_model_id
Create a JWT-encoded model ID for serve multiplexing.
def create_serve_multiplexed_model_id(
model_id: int | str,
token: str,
backend_url: str,
tenant: str | None = None,
) -> str
| Parameter | Type | Required | Description |
|---|---|---|---|
model_id | int | str | Yes | Model ID to encode |
token | str | Yes | User access token |
backend_url | str | Yes | Backend URL (used as JWT secret) |
tenant | str | None | No | Tenant identifier |
Returns: JWT-encoded model token string.
from synapse_sdk.plugins.actions.inference import create_serve_multiplexed_model_id
model_token = create_serve_multiplexed_model_id(
model_id=123,
token='user_access_token',
backend_url='https://api.example.com',
tenant='my-tenant',
)
# Use in request headers:
headers = {'serve_multiplexed_model_id': model_token}
Example: Model Deployment
Complete example of deploying a PyTorch model with Ray Serve:
from typing import Any
from fastapi import FastAPI
from pydantic import BaseModel
from synapse_sdk.plugins.actions.inference import (
BaseDeploymentAction,
BaseServeDeployment,
)
app = FastAPI()
class PyTorchInference(BaseServeDeployment):
"""PyTorch inference deployment.
The @serve.deployment and @serve.ingress(app) decorators are applied
automatically by BaseDeploymentAction.deploy().
"""
action_name = 'inference'
app = app
async def _get_model(self, model_info: dict[str, Any]) -> Any:
import torch
model_path = model_info['path'] / 'model.pt'
model = torch.load(model_path)
model.eval()
return model
@app.post('/')
async def infer(self, inputs: list[dict]) -> list[dict]:
model = await self.get_model()
import torch
results = []
for inp in inputs:
tensor = torch.tensor(inp['data'])
with torch.no_grad():
prediction = model(tensor)
results.append({'prediction': prediction.tolist()})
return results
class DeploymentParams(BaseModel):
model: int
num_gpus: int = 1
class MyDeploymentAction(BaseDeploymentAction[DeploymentParams]):
"""Deploy PyTorch model to Ray Serve."""
action_name = 'deployment'
entrypoint = PyTorchInference
def execute(self) -> dict[str, Any]:
# Initialize Ray
self.ray_init()
self.set_progress(1, 3, self.progress.INITIALIZE)
# Deploy to Ray Serve
self.deploy()
self.set_progress(2, 3, self.progress.DEPLOY)
# Register with backend
app_id = self.register_serve_application()
self.set_progress(3, 3, self.progress.REGISTER)
return {'serve_application': app_id}
Capacity Gate (SYN-7005)
BaseDeploymentAction.deploy() runs a pre-flight cluster capacity gate before invoking serve.run(...). The gate calls the agent's POST /resources/feasibility/ endpoint via ResourceClientMixin.check_feasibility and fails closed when the cluster cannot accept the deployment.
Branches
| Condition | Behavior | Exception / Log |
|---|---|---|
ctx.agent_client is None | Graceful skip — serve.run proceeds | WARNING: serve capacity gate skipped: agent_client not provided |
Agent response allowed=True | Proceed to serve.run | — |
Agent response allowed=False | Block deploy | RuntimeError('Serve deploy denied: insufficient cluster capacity ([reasons])') |
Agent unreachable (ClientError / ClientTimeoutError / TimeoutError / ConnectionError / OSError) | Block deploy (fail-closed) | RuntimeError('Serve deploy capacity check failed: agent unreachable ({error_type})') |
Malformed response (not dict, missing allowed) | Block deploy (fail-closed) | RuntimeError('Serve deploy capacity check failed: agent unreachable (MalformedResponse)') |
The original exception is preserved via __cause__ for diagnostics, but the wrapper message intentionally omits the cause string to avoid leaking agent URLs or tokens.
Request Payload
The gate derives its payload from ray_actor_options and serve_options:
| Key | Source | Default | Notes |
|---|---|---|---|
kind | hardcoded | 'serve' | Required by agent endpoint |
num_cpus | ray_actor_options['num_cpus'] | 1.0 | Coerced to float |
num_gpus | ray_actor_options['num_gpus'] | 0.0 | Coerced to float |
memory_bytes | ray_actor_options['memory'] | None | Coerced to int when present |
replicas | serve_options['num_replicas'] | 1 | Coerced to int |
metadata.plugin_code | get_serve_app_name() | — | e.g. '[email protected]' |
metadata.agent | SYNAPSE_AGENT_ID env var | '' | — |
Diagnostics
For troubleshooting "Serve deploy denied" errors, inspect reasons in the error message and cross-reference the agent's gatekeeping policy via get_gatekeeping_policy. For "agent unreachable" errors, the underlying exception type (ClientError / TimeoutError / etc.) is included to help isolate network vs. agent-side issues.
try:
self.deploy()
except RuntimeError as exc:
if 'insufficient cluster capacity' in str(exc):
# Cluster is full — surface to user / retry later
self.log('capacity_denied', {'error': str(exc)})
elif 'agent unreachable' in str(exc):
# Transient network / agent issue — original exception on __cause__
self.log('agent_unreachable', {'cause': type(exc.__cause__).__name__})
raise
The SDK emits a single English message. The backend converts these to localized ValidationError (ko / en) when the deploy is dispatched via the to_task serializer path.
The step-based deployment path (setup_steps registry) and the async client (AsyncAgentClient) capacity gate are deferred to follow-up tasks. Only the simple execute() → deploy() path is gated in this release.
Running Inference
Once deployed, call the serve endpoint via the inference action:
synapse plugin run inference --mode job --params '{"model": 34, "method": "post", "json": {"inputs": [{"data": [1, 2, 3]}]}}'
The entrypoint detects BaseServeDeployment subclasses and calls infer_remote(), which resolves the deployed endpoint's route prefix and forwards the request.
Best Practices
Model Caching
Cache loaded models to avoid repeated downloads:
class CachedInferenceAction(BaseInferenceAction[InferenceParams]):
_model_cache: dict[int, Any] = {}
def load_model_cached(self, model_id: int) -> Any:
if model_id not in self._model_cache:
model_info = self.load_model(model_id)
self._model_cache[model_id] = torch.load(model_info['path'] + '/model.pt')
return self._model_cache[model_id]
Batch Processing Optimization
For BaseDeploymentAction serving, use the batch config in serve_options to enable automatic GPU batching via Ray Serve. This is the recommended approach for real-time inference.
For BaseInferenceAction offline processing, batch inputs manually:
def execute(self) -> dict[str, Any]:
model = self.load_model_cached(self.params.model_id)
# Process in batches
batch_size = self.params.batch_size
results = []
for i in range(0, len(self.params.inputs), batch_size):
batch = self.params.inputs[i : i + batch_size]
batch_results = self._process_batch(model, batch)
results.extend(batch_results)
# Update progress
progress = min((i + batch_size) / len(self.params.inputs), 1.0)
self.set_progress(int(progress * 100), 100, self.progress.INFERENCE)
return {'results': results}
Error Handling
Handle model loading and inference errors gracefully:
def execute(self) -> dict[str, Any]:
try:
model_info = self.load_model(self.params.model_id)
except ValueError as e:
self.log('model_load_error', {'error': str(e)})
return {'error': f'Failed to load model: {e}'}
try:
results = self.infer(model_info, self.params.inputs)
except Exception as e:
self.log('inference_error', {'error': str(e)})
return {'error': f'Inference failed: {e}', 'partial_results': []}
return {'results': results}
Related
- Defining Actions - Core action base class
- Step-Based Workflows - Building complex workflows with steps
- Train Actions - Model training workflows