Inference Actions

Use inference actions to build model inference workflows. Use BaseInferenceAction for batch inference and BaseDeploymentAction for REST API serving via Ray Serve.

Overview

Synapse SDK provides two base classes for inference workflows:

Base Class            Purpose           Use Case
BaseInferenceAction   Batch inference   Processing datasets, offline predictions
BaseDeploymentAction  REST API serving  Real-time inference endpoints via Ray Serve

Both classes support two execution modes:

  • Simple Mode: Override execute() directly for straightforward workflows
  • Step-Based Mode: Use setup_steps() to register workflow steps for complex pipelines

BaseInferenceAction

The base class for inference actions. Provides helper methods for model loading and inference workflows.

class BaseInferenceAction(BaseAction[P]):
    category = PluginCategory.NEURAL_NET
    progress = InferenceProgressCategories()

Progress Categories

Track inference progress with these standard categories:

Category     Value          Description
MODEL_LOAD   'model_load'   Model loading and initialization
INFERENCE    'inference'    Running inference on inputs
POSTPROCESS  'postprocess'  Post-processing results

self.set_progress(1, 3, self.progress.MODEL_LOAD)
self.set_progress(2, 3, self.progress.INFERENCE)
self.set_progress(3, 3, self.progress.POSTPROCESS)

Helper Methods

get_model

Retrieve model metadata by ID.

def get_model(self, model_id: int) -> dict[str, Any]

Parameter  Type  Required  Description
model_id   int   Yes       Model identifier

Returns: Model metadata dictionary including file URL.

model = self.get_model(123)
print(model['name'], model['file'])

download_model

Download and extract model artifacts.

def download_model(
    self,
    model_id: int,
    output_dir: str | Path | None = None,
) -> Path

Parameter   Type               Required  Default  Description
model_id    int                Yes       -        Model identifier
output_dir  str | Path | None  No        None     Directory to extract the model to. Uses a temporary directory if None

Returns: Path to extracted model directory.

model_path = self.download_model(123)
# model_path contains extracted model artifacts

load_model

Load model for inference. Downloads artifacts and returns model info with local path.

def load_model(self, model_id: int) -> dict[str, Any]

Parameter  Type  Required  Description
model_id   int   Yes       Model identifier

Returns: Model metadata dict with 'path' key for local artifacts.

model_info = self.load_model(123)
model_path = model_info['path']
# Load your model framework here:
# model = torch.load(model_path / 'model.pt')

infer

Run inference on inputs. Override this method to implement your inference logic.

def infer(
    self,
    model: Any,
    inputs: list[dict[str, Any]],
) -> list[dict[str, Any]]

Parameter  Type                  Required  Description
model      Any                   Yes       Loaded model (framework-specific)
inputs     list[dict[str, Any]]  Yes       List of input dictionaries

Returns: List of result dictionaries.

def infer(self, model, inputs):
    results = []
    for inp in inputs:
        prediction = model.predict(inp['image'])
        results.append({'prediction': prediction})
    return results

Execution Modes

Simple Mode

Override execute() directly for simple workflows:

from pathlib import Path
from typing import Any

from synapse_sdk.plugins.actions.inference import BaseInferenceAction
from pydantic import BaseModel


class InferenceParams(BaseModel):
    model_id: int
    inputs: list[dict]


class MyInferenceAction(BaseInferenceAction[InferenceParams]):
    action_name = 'inference'

    def execute(self) -> dict[str, Any]:
        # Load model metadata; artifacts are available under model_info['path']
        model_info = self.load_model(self.params.model_id)
        self.set_progress(1, 3, self.progress.MODEL_LOAD)

        # Run inference
        results = self.infer(model_info, self.params.inputs)
        self.set_progress(2, 3, self.progress.INFERENCE)

        # Post-process (_postprocess is your own helper, not shown here)
        processed = self._postprocess(results)
        self.set_progress(3, 3, self.progress.POSTPROCESS)

        return {'results': processed}

    def infer(self, model, inputs):
        import torch

        # Path() accepts both str and Path values for model['path']
        model_obj = torch.load(Path(model['path']) / 'model.pt')
        results = []
        for inp in inputs:
            pred = model_obj(inp['tensor'])
            results.append({'prediction': pred.tolist()})
        return results

Step-Based Mode

Use setup_steps() to register workflow steps for complex pipelines:

from synapse_sdk.plugins.actions.inference import (
    BaseInferenceAction,
    InferenceContext,
)
from synapse_sdk.plugins.steps import BaseStep, StepResult, StepRegistry


class LoadModelStep(BaseStep[InferenceContext]):
    @property
    def name(self) -> str:
        return 'load_model'

    @property
    def progress_weight(self) -> float:
        return 0.3

    def execute(self, context: InferenceContext) -> StepResult:
        # Load model using context
        import torch

        model_path = context.model_path
        context.model = torch.load(f'{model_path}/model.pt')
        return StepResult(success=True)


class InferenceStep(BaseStep[InferenceContext]):
    @property
    def name(self) -> str:
        return 'inference'

    @property
    def progress_weight(self) -> float:
        return 0.7

    def execute(self, context: InferenceContext) -> StepResult:
        for request in context.requests:
            prediction = context.model(request['input'])
            context.results.append({'prediction': prediction})
            context.processed_count += 1
        return StepResult(success=True)


class MyInferenceAction(BaseInferenceAction[InferenceParams]):
    action_name = 'inference'

    def setup_steps(self, registry: StepRegistry[InferenceContext]) -> None:
        registry.register(LoadModelStep())
        registry.register(InferenceStep())
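
This section does not show how the SDK combines progress_weight values into an overall percentage. The following is a minimal sketch of one plausible scheme, assuming weights are normalized by their sum and each step reports a completion fraction between 0 and 1. overall_progress is a hypothetical helper, not an SDK function.

```python
def overall_progress(weights: dict[str, float], completed: dict[str, float]) -> float:
    """Combine per-step completion fractions into one overall fraction.

    weights:   step name -> progress_weight (e.g. 0.3 and 0.7 above)
    completed: step name -> fraction done in [0, 1]; missing steps count as 0
    """
    total = sum(weights.values())
    if total <= 0:
        raise ValueError('at least one step needs a positive weight')
    done = 0.0
    for name, weight in weights.items():
        # Clamp so a misreported fraction cannot push progress out of range
        fraction = min(max(completed.get(name, 0.0), 0.0), 1.0)
        done += weight * fraction
    return done / total


weights = {'load_model': 0.3, 'inference': 0.7}
after_load = overall_progress(weights, {'load_model': 1.0})
halfway = overall_progress(weights, {'load_model': 1.0, 'inference': 0.5})
```

Under these assumptions, finishing load_model accounts for 30% of the bar and the inference step for the remaining 70%.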

InferenceContext

Context for inference action step-based workflows. Extends BaseStepContext with inference-specific state.

@dataclass
class InferenceContext(BaseStepContext):
    params: dict[str, Any] = field(default_factory=dict)
    model_id: int | None = None
    model: dict[str, Any] | None = None
    model_path: str | None = None
    requests: list[dict[str, Any]] = field(default_factory=list)
    results: list[dict[str, Any]] = field(default_factory=list)
    batch_size: int = 1
    processed_count: int = 0

Attribute        Type                   Description
params           dict[str, Any]         Action parameters
model_id         int | None             ID of the model being used
model            dict[str, Any] | None  Loaded model information from backend
model_path       str | None             Local path to downloaded model
requests         list[dict[str, Any]]   Input requests to process
results          list[dict[str, Any]]   Inference results
batch_size       int                    Batch size for processing
processed_count  int                    Number of processed items

Example: Batch Inference

Complete example of a batch inference action with PyTorch:

from pathlib import Path
from typing import Any

import torch
from pydantic import BaseModel

from synapse_sdk.plugins.actions.inference import BaseInferenceAction


class BatchInferenceParams(BaseModel):
    model_id: int
    inputs: list[dict[str, Any]]
    batch_size: int = 32


class BatchInferenceAction(BaseInferenceAction[BatchInferenceParams]):
    """Batch inference action for PyTorch models."""

    action_name = 'batch_inference'

    def execute(self) -> dict[str, Any]:
        # Step 1: Load model
        model_info = self.load_model(self.params.model_id)
        model_path = Path(model_info['path'])
        # Note: on PyTorch >= 2.6, torch.load defaults to weights_only=True, so a
        # fully pickled model needs weights_only=False (trusted files only)
        model = torch.load(model_path / 'model.pt')
        model.eval()
        self.set_progress(1, 3, self.progress.MODEL_LOAD)

        # Step 2: Run inference in batches
        results = []
        inputs = self.params.inputs
        batch_size = self.params.batch_size

        for i in range(0, len(inputs), batch_size):
            batch = inputs[i : i + batch_size]
            batch_tensors = torch.stack([torch.tensor(inp['data']) for inp in batch])

            with torch.no_grad():
                predictions = model(batch_tensors)

            for pred in predictions:
                results.append({'prediction': pred.tolist()})

        self.set_progress(2, 3, self.progress.INFERENCE)

        # Step 3: Post-process
        self.set_progress(3, 3, self.progress.POSTPROCESS)

        return {
            'results': results,
            'processed_count': len(results),
        }

    def infer(self, model, inputs):
        # Optional: Override for custom inference logic
        pass
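
The slicing loop in execute() can be factored into a small generator that also validates batch_size. With the plain range() loop, a batch_size of 0 raises ValueError from range(), and a negative value silently yields no batches at all. The sketch below is illustrative only; iter_batches is a hypothetical helper name and is not part of the SDK.

```python
from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar('T')


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1')
    for start in range(0, len(items), batch_size):
        # The final slice may be shorter than batch_size
        yield items[start : start + batch_size]


batches = list(iter_batches([1, 2, 3, 4, 5], batch_size=2))
```

Here batches is [[1, 2], [3, 4], [5]]: the last batch carries the remainder instead of being dropped.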

Deployment Actions

Use deployment actions to serve REST APIs via Ray Serve. Use BaseDeploymentAction to deploy inference endpoints.

BaseDeploymentAction

Base class for Ray Serve deployment actions. Handles Ray initialization, deployment creation, and backend registration.

class BaseDeploymentAction(BaseAction[P]):
    progress = DeploymentProgressCategories()
    entrypoint: type | None = None  # Set to your serve deployment class

DeploymentProgressCategories

Category    Value         Description
INITIALIZE  'initialize'  Ray cluster initialization
DEPLOY      'deploy'      Deploying to Ray Serve
REGISTER    'register'    Registering with backend

Deployment Methods

ray_init

Initialize Ray cluster connection.

def ray_init(self, **kwargs: Any) -> None

deploy

Deploy the inference endpoint to Ray Serve.

def deploy(self) -> None

Internally calls a pre-flight cluster capacity gate (_check_serve_capacity) immediately before serve.run(...). See Capacity Gate below.

register_serve_application

Register the serve application with the backend.

def register_serve_application(self) -> int | None

Returns: Serve application ID if created, None otherwise.

Configuration Methods

Override these methods to customize deployment:

Method                   Default                                   Description
get_serve_app_name()     SYNAPSE_PLUGIN_RELEASE_CODE env var       Serve application name
get_route_prefix()       SYNAPSE_PLUGIN_RELEASE_CHECKSUM env var   URL route prefix
get_ray_actor_options()  Extracted from params                     Ray actor options (num_cpus, num_gpus, memory)
get_runtime_env()        {}                                        Ray runtime environment

BaseServeDeployment

Base class for Ray Serve inference deployments. Inherits from BaseAction and provides model loading with multiplexing support.

class BaseServeDeployment(BaseAction):
    def __init__(self, backend_url: str) -> None:
        self.backend_url = backend_url
        self._model_cache: dict[str, Any] = {}

Abstract methods to implement:

Method                                     Description
async _get_model(model_info: dict) -> Any  Load model from extracted artifacts
async infer(*args, **kwargs) -> Any        Run inference on inputs

Class attributes:

Attribute    Description
action_name  Name for action discovery (e.g., 'inference')
app          FastAPI app instance (decorators applied automatically by deploy())

infer_remote

Call the deployed serve endpoint for inference. Used by the entrypoint when an inference action is executed.

@classmethod
def infer_remote(cls, params: dict[str, Any], ctx: Any) -> Any

Params format:

Key     Type       Required  Description
model   int | str  No        Model ID for multiplexing
method  str        No        HTTP method (default: 'post')
json    dict       Yes       Request body sent to the serve endpoint

# Inference params payload
params = {
    'model': 34,
    'method': 'post',
    'json': {
        'image_path': 'https://example.com/image.jpg',
        'threshold': 0.5,
    },
}

Batch Inference (@serve.batch)

Enable GPU-efficient batch inference by configuring batch in serve_options. Ray Serve automatically collects concurrent requests and passes them to infer_batch() as a list.

Configuration

Add the batch key to your config.yaml:

actions:
  deployment:
    entrypoint: plugin.inference.InferenceDeployment
    method: job
    serve_options:
      num_replicas: 1
      batch:
        max_batch_size: 16         # Maximum requests to batch together
        batch_wait_timeout_s: 0.1  # Seconds to wait for batch to fill

Key                   Type   Default  Description
max_batch_size        int    8        Maximum number of requests per batch
batch_wait_timeout_s  float  0.1      Seconds to wait for more requests before processing

When batch is not configured, the deployment processes requests individually (backward compatible).

How It Works

  1. The client sends N concurrent requests.
  2. Ray Serve collects them (up to max_batch_size within batch_wait_timeout_s).
  3. Ray Serve calls infer_batch([req1, req2, ..., reqN]).
  4. The GPU processes all inputs in a single forward pass.
  5. Results are distributed back to the individual callers.
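
The flow above can be imitated with plain asyncio to see how the two settings interact. This is a toy model only: it is not Ray Serve's implementation, and MiniBatcher, submit, and close are hypothetical names introduced for illustration. A background worker takes the first waiting request, keeps collecting until either max_batch_size is reached or batch_wait_timeout_s has elapsed, calls the batch handler once, and hands each caller its own result.

```python
import asyncio


class MiniBatcher:
    """Toy request batcher (illustrative stand-in, not Ray Serve)."""

    def __init__(self, handler, max_batch_size=8, batch_wait_timeout_s=0.1):
        self._handler = handler
        self._max = max_batch_size
        self._timeout = batch_wait_timeout_s
        self._queue = asyncio.Queue()
        self._worker = None

    async def submit(self, request):
        # Start the worker lazily, then wait for this request's own result
        if self._worker is None:
            self._worker = asyncio.create_task(self._run())
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((request, future))
        return await future

    async def _run(self):
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._queue.get()]
            deadline = loop.time() + self._timeout
            while len(batch) < self._max:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            # One handler call per batch; results come back in request order
            results = await self._handler([request for request, _ in batch])
            for (_, future), result in zip(batch, results):
                future.set_result(result)

    def close(self):
        if self._worker is not None:
            self._worker.cancel()


async def demo():
    batch_sizes = []

    async def infer_batch(requests):
        batch_sizes.append(len(requests))
        return [value * 2 for value in requests]

    batcher = MiniBatcher(infer_batch, max_batch_size=4, batch_wait_timeout_s=0.05)
    results = await asyncio.gather(*(batcher.submit(i) for i in range(6)))
    batcher.close()
    return results, batch_sizes
```

Running asyncio.run(demo()) sends six concurrent requests through a batcher capped at four, so the handler is invoked more than once and no batch exceeds the cap. A larger batch_wait_timeout_s trades latency for fuller batches.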

Default Behavior

BaseServeDeployment.infer_batch() has a default implementation that falls back to calling infer() for each request individually. This means existing plugins work without modification — they just don't get GPU batching benefits.

# Default fallback (no override needed)
async def infer_batch(self, requests: list[Any]) -> list[Any]:
    return [await self.infer(req) for req in requests]

GPU-Optimized Batch Inference

Override infer_batch() to process multiple inputs in a single GPU forward pass:

class YOLOServeDeployment(BaseServeDeployment):
    action_name = 'inference'
    app = app

    async def _get_model(self, model_info: dict[str, Any]) -> Any:
        from ultralytics import YOLO

        return YOLO(model_info['path'] / 'best.pt')

    @app.post('/')
    async def infer(self, data: ImageData) -> dict[str, Any]:
        """Single request: calls infer_batch internally via @serve.batch."""
        return await self.infer_batch(data)

    async def infer_batch(self, requests: list[ImageData]) -> list[dict[str, Any]]:
        """Batch inference: the GPU processes all images at once."""
        model = await self.get_model()
        images = [req.cv_image for req in requests]
        results = model(images, conf=requests[0].threshold)
        return [serialize(r) for r in results]

Note: The @serve.batch decorator is applied automatically during deployment based on your config.yaml. Do not apply it yourself.

Model Type                              Recommended max_batch_size  Notes
ResNet / EfficientNet (classification)  32–64                       Lightweight, high throughput
YOLO (detection)                        8–16                        Moderate GPU memory usage
SAM / segmentation                      2–4                         Heavy GPU memory usage

DeploymentContext

Context for deployment action step-based workflows.

@dataclass
class DeploymentContext(BaseStepContext):
    params: dict[str, Any] = field(default_factory=dict)
    model_id: int | None = None
    model: dict[str, Any] | None = None
    model_path: str | None = None
    serve_app_name: str | None = None
    serve_app_id: int | None = None
    route_prefix: str | None = None
    ray_actor_options: dict[str, Any] = field(default_factory=dict)
    deployed: bool = False

Attribute          Type            Description
serve_app_name     str | None      Ray Serve application name
serve_app_id       int | None      ID of created serve application
route_prefix       str | None      URL route prefix for deployment
ray_actor_options  dict[str, Any]  Ray actor configuration
deployed           bool            Whether deployment succeeded

create_serve_multiplexed_model_id

Create a JWT-encoded model ID for serve multiplexing.

def create_serve_multiplexed_model_id(
    model_id: int | str,
    token: str,
    backend_url: str,
    tenant: str | None = None,
) -> str

Parameter    Type        Required  Description
model_id     int | str   Yes       Model ID to encode
token        str         Yes       User access token
backend_url  str         Yes       Backend URL (used as JWT secret)
tenant       str | None  No        Tenant identifier

Returns: JWT-encoded model token string.

from synapse_sdk.plugins.actions.inference import create_serve_multiplexed_model_id

model_token = create_serve_multiplexed_model_id(
    model_id=123,
    token='user_access_token',
    backend_url='https://api.example.com',
    tenant='my-tenant',
)
# Use in request headers:
headers = {'serve_multiplexed_model_id': model_token}
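
To make "JWT-encoded" concrete, here is a standard-library sketch of signing and verifying a token with the backend URL as the secret. It makes two assumptions that this section does not confirm: that the signature algorithm is HS256, and that the payload uses the field names shown. encode_hs256 and decode_hs256 are hypothetical helpers, not SDK functions.

```python
import base64
import hashlib
import hmac
import json


def _b64url(data: bytes) -> str:
    """Base64url-encode without padding, as JWT segments require."""
    return base64.urlsafe_b64encode(data).rstrip(b'=').decode('ascii')


def _sign(signing_input: str, secret: str) -> str:
    digest = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest()
    return _b64url(digest)


def encode_hs256(payload: dict, secret: str) -> str:
    """Build header.payload.signature using HMAC-SHA256."""
    header = {'alg': 'HS256', 'typ': 'JWT'}
    segments = [
        _b64url(json.dumps(part, separators=(',', ':')).encode())
        for part in (header, payload)
    ]
    signing_input = '.'.join(segments)
    return f'{signing_input}.{_sign(signing_input, secret)}'


def decode_hs256(token: str, secret: str) -> dict:
    """Verify the signature, then return the decoded payload."""
    signing_input, _, signature = token.rpartition('.')
    if not hmac.compare_digest(signature, _sign(signing_input, secret)):
        raise ValueError('signature mismatch')
    payload_b64 = signing_input.split('.')[1]
    padded = payload_b64 + '=' * (-len(payload_b64) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))


# Hypothetical payload layout; the SDK's real claim names may differ
claims = {'model_id': 123, 'token': 'user_access_token', 'tenant': 'my-tenant'}
sketch_token = encode_hs256(claims, secret='https://api.example.com')
```

Only a party that knows the same secret can verify the token and recover the claims, which lets a single header value carry the model ID, access token, and tenant together.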

Example: Model Deployment

Complete example of deploying a PyTorch model with Ray Serve:

from typing import Any

from fastapi import FastAPI
from pydantic import BaseModel

from synapse_sdk.plugins.actions.inference import (
    BaseDeploymentAction,
    BaseServeDeployment,
)

app = FastAPI()


class PyTorchInference(BaseServeDeployment):
    """PyTorch inference deployment.

    The @serve.deployment and @serve.ingress(app) decorators are applied
    automatically by BaseDeploymentAction.deploy().
    """

    action_name = 'inference'
    app = app

    async def _get_model(self, model_info: dict[str, Any]) -> Any:
        import torch

        model_path = model_info['path'] / 'model.pt'
        model = torch.load(model_path)
        model.eval()
        return model

    @app.post('/')
    async def infer(self, inputs: list[dict]) -> list[dict]:
        model = await self.get_model()

        import torch

        results = []
        for inp in inputs:
            tensor = torch.tensor(inp['data'])
            with torch.no_grad():
                prediction = model(tensor)
            results.append({'prediction': prediction.tolist()})

        return results


class DeploymentParams(BaseModel):
    model: int
    num_gpus: int = 1


class MyDeploymentAction(BaseDeploymentAction[DeploymentParams]):
    """Deploy PyTorch model to Ray Serve."""

    action_name = 'deployment'
    entrypoint = PyTorchInference

    def execute(self) -> dict[str, Any]:
        # Initialize Ray
        self.ray_init()
        self.set_progress(1, 3, self.progress.INITIALIZE)

        # Deploy to Ray Serve
        self.deploy()
        self.set_progress(2, 3, self.progress.DEPLOY)

        # Register with backend
        app_id = self.register_serve_application()
        self.set_progress(3, 3, self.progress.REGISTER)

        return {'serve_application': app_id}

Capacity Gate (SYN-7005)

BaseDeploymentAction.deploy() runs a pre-flight cluster capacity gate before invoking serve.run(...). The gate calls the agent's POST /resources/feasibility/ endpoint via ResourceClientMixin.check_feasibility and fails closed when the cluster cannot accept the deployment.

Branches

Each condition below is listed with its behavior and the resulting exception or log:

  • ctx.agent_client is None: graceful skip; serve.run proceeds. Logs WARNING: serve capacity gate skipped: agent_client not provided
  • Agent response allowed=True: proceeds to serve.run.
  • Agent response allowed=False: deploy is blocked. Raises RuntimeError('Serve deploy denied: insufficient cluster capacity ([reasons])')
  • Agent unreachable (ClientError / ClientTimeoutError / TimeoutError / ConnectionError / OSError): deploy is blocked (fail-closed). Raises RuntimeError('Serve deploy capacity check failed: agent unreachable ({error_type})')
  • Malformed response (not a dict, or missing allowed): deploy is blocked (fail-closed). Raises RuntimeError('Serve deploy capacity check failed: agent unreachable (MalformedResponse)')

The original exception is preserved via __cause__ for diagnostics, but the wrapper message intentionally omits the cause string to avoid leaking agent URLs or tokens.
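
The branches above amount to a short decision function. The sketch below is an approximation under stated assumptions, not the SDK's _check_serve_capacity: it assumes check_feasibility takes the payload as its only argument, that denial reasons arrive under a 'reasons' key, and it catches only built-in exception types, since ClientError and ClientTimeoutError are SDK classes not available here.

```python
import logging
from typing import Any, Optional

logger = logging.getLogger('capacity_gate_sketch')

# Built-in stand-ins for the "agent unreachable" exception family
UNREACHABLE_ERRORS = (TimeoutError, ConnectionError, OSError)


def check_serve_capacity(agent_client: Optional[Any], payload: dict) -> None:
    """Return normally when the deploy may proceed; raise RuntimeError otherwise."""
    if agent_client is None:
        logger.warning('serve capacity gate skipped: agent_client not provided')
        return
    try:
        response = agent_client.check_feasibility(payload)  # assumed signature
    except UNREACHABLE_ERRORS as exc:
        # Fail closed; keep the original on __cause__ but out of the message
        raise RuntimeError(
            f'Serve deploy capacity check failed: agent unreachable ({type(exc).__name__})'
        ) from exc
    if not isinstance(response, dict) or 'allowed' not in response:
        raise RuntimeError(
            'Serve deploy capacity check failed: agent unreachable (MalformedResponse)'
        )
    if not response['allowed']:
        reasons = ', '.join(str(r) for r in response.get('reasons', []))  # assumed key
        raise RuntimeError(f'Serve deploy denied: insufficient cluster capacity ({reasons})')


class StubAgentClient:
    """Hypothetical test double that returns a canned response or raises."""

    def __init__(self, response: Any = None, error: Optional[Exception] = None):
        self.response = response
        self.error = error

    def check_feasibility(self, payload: dict) -> Any:
        if self.error is not None:
            raise self.error
        return self.response
```

Calling check_serve_capacity(StubAgentClient(response={'allowed': True}), {}) returns without raising, while a stub constructed with error=ConnectionError('down') produces the fail-closed RuntimeError with the original exception on __cause__.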

Request Payload

The gate derives its payload from ray_actor_options and serve_options:

Key                   Source                         Default  Notes
kind                  hardcoded                      'serve'  Required by agent endpoint
num_cpus              ray_actor_options['num_cpus']  1.0      Coerced to float
num_gpus              ray_actor_options['num_gpus']  0.0      Coerced to float
memory_bytes          ray_actor_options['memory']    None     Coerced to int when present
replicas              serve_options['num_replicas']  1        Coerced to int
metadata.plugin_code  get_serve_app_name()                    e.g. '[email protected]'
metadata.agent        SYNAPSE_AGENT_ID env var       ''
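
The defaults and coercions in the table can be written as one small function. This is a sketch based only on the table: build_feasibility_payload is a hypothetical name, and representing the metadata.* keys as a nested dict is an assumption about the wire format.

```python
import os
from typing import Any, Optional


def build_feasibility_payload(
    ray_actor_options: dict,
    serve_options: dict,
    plugin_code: str,
    agent_id: Optional[str] = None,
) -> dict:
    """Derive the capacity-gate payload using the table's defaults and coercions."""
    memory = ray_actor_options.get('memory')
    if agent_id is None:
        agent_id = os.environ.get('SYNAPSE_AGENT_ID', '')
    return {
        'kind': 'serve',
        'num_cpus': float(ray_actor_options.get('num_cpus', 1.0)),
        'num_gpus': float(ray_actor_options.get('num_gpus', 0.0)),
        # memory stays None when the actor options do not set it
        'memory_bytes': int(memory) if memory is not None else None,
        'replicas': int(serve_options.get('num_replicas', 1)),
        # Nested dict is an assumption; the table only lists dotted key names
        'metadata': {'plugin_code': plugin_code, 'agent': agent_id},
    }


payload = build_feasibility_payload(
    ray_actor_options={'num_gpus': 1},
    serve_options={'num_replicas': '2'},
    plugin_code='my-plugin',  # placeholder value for illustration
    agent_id='',
)
```

In this example num_cpus falls back to 1.0, num_gpus becomes 1.0, memory_bytes stays None, and the string '2' is coerced to the integer 2 for replicas.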

Diagnostics

For troubleshooting "Serve deploy denied" errors, inspect reasons in the error message and cross-reference the agent's gatekeeping policy via get_gatekeeping_policy. For "agent unreachable" errors, the underlying exception type (ClientError / TimeoutError / etc.) is included to help isolate network vs. agent-side issues.

try:
    self.deploy()
except RuntimeError as exc:
    if 'insufficient cluster capacity' in str(exc):
        # Cluster is full: surface to user / retry later
        self.log('capacity_denied', {'error': str(exc)})
    elif 'agent unreachable' in str(exc):
        # Transient network / agent issue: original exception is on __cause__
        self.log('agent_unreachable', {'cause': type(exc.__cause__).__name__})
    raise

Localization

The SDK emits each message in English only. The backend converts these messages to a localized ValidationError (ko / en) when the deploy is dispatched via the to_task serializer path.

Out of Scope (SYN-7005)

The step-based deployment path (setup_steps registry) and the async client (AsyncAgentClient) capacity gate are deferred to follow-up tasks. Only the simple execute() → deploy() path is gated in this release.

Running Inference

Once deployed, call the serve endpoint via the inference action:

synapse plugin run inference --mode job --params '{"model": 34, "method": "post", "json": {"inputs": [{"data": [1, 2, 3]}]}}'

The entrypoint detects BaseServeDeployment subclasses and calls infer_remote(), which resolves the deployed endpoint's route prefix and forwards the request.
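
Hand-writing the --params JSON inside shell quotes is error-prone once the body contains quotes or nested objects. A small sketch that builds the same command string with json.dumps and shlex.quote follows; build_inference_command is a hypothetical helper name, and the command-line flags are copied from the example above.

```python
import json
import shlex


def build_inference_command(model, body: dict, method: str = 'post') -> str:
    """Return a shell-safe 'synapse plugin run inference' command line."""
    params = {'model': model, 'method': method, 'json': body}
    # shlex.quote protects the JSON from shell word splitting and expansion
    return 'synapse plugin run inference --mode job --params ' + shlex.quote(json.dumps(params))


command = build_inference_command(34, {'inputs': [{'data': [1, 2, 3]}]})
```

Splitting command with shlex.split yields the JSON as a single final argument, which is what the shell would pass to the CLI.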


Best Practices

Model Caching

Cache loaded models to avoid repeated downloads:

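from pathlib import Path
from typing import Any

import torch


class CachedInferenceAction(BaseInferenceAction[InferenceParams]):
    # Class-level cache: shared by all instances in the same process
    _model_cache: dict[int, Any] = {}

    def load_model_cached(self, model_id: int) -> Any:
        if model_id not in self._model_cache:
            model_info = self.load_model(model_id)
            # Path() accepts both str and Path values for model_info['path']
            self._model_cache[model_id] = torch.load(Path(model_info['path']) / 'model.pt')
        return self._model_cache[model_id]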

Batch Processing Optimization

For BaseDeploymentAction serving, use the batch config in serve_options to enable automatic GPU batching via Ray Serve. This is the recommended approach for real-time inference.

For BaseInferenceAction offline processing, batch inputs manually:

def execute(self) -> dict[str, Any]:
    model = self.load_model_cached(self.params.model_id)

    # Process in batches
    batch_size = self.params.batch_size
    results = []

    for i in range(0, len(self.params.inputs), batch_size):
        batch = self.params.inputs[i : i + batch_size]
        batch_results = self._process_batch(model, batch)
        results.extend(batch_results)

        # Update progress
        progress = min((i + batch_size) / len(self.params.inputs), 1.0)
        self.set_progress(int(progress * 100), 100, self.progress.INFERENCE)

    return {'results': results}

Error Handling

Handle model loading and inference errors gracefully:

def execute(self) -> dict[str, Any]:
    try:
        model_info = self.load_model(self.params.model_id)
    except ValueError as e:
        self.log('model_load_error', {'error': str(e)})
        return {'error': f'Failed to load model: {e}'}

    try:
        results = self.infer(model_info, self.params.inputs)
    except Exception as e:
        self.log('inference_error', {'error': str(e)})
        return {'error': f'Inference failed: {e}', 'partial_results': []}

    return {'results': results}