Inference Actions

Inference actions let you build model inference workflows: use BaseInferenceAction for batch inference and BaseDeploymentAction for REST API serving via Ray Serve.

Overview

Synapse SDK provides two base classes for inference workflows:

| Base Class | Purpose | Use Case |
| --- | --- | --- |
| BaseInferenceAction | Batch inference | Processing datasets, offline predictions |
| BaseDeploymentAction | REST API serving | Real-time inference endpoints via Ray Serve |

Both classes support two execution modes:

  • Simple Mode: Override execute() directly for straightforward workflows
  • Step-Based Mode: Use setup_steps() to register workflow steps for complex pipelines

BaseInferenceAction

The base class for inference actions. Provides helper methods for model loading and inference workflows.

class BaseInferenceAction(BaseAction[P]):
    category = PluginCategory.NEURAL_NET
    progress = InferenceProgressCategories()

Progress Categories

Track inference progress with these standard categories:

| Category | Value | Description |
| --- | --- | --- |
| MODEL_LOAD | 'model_load' | Model loading and initialization |
| INFERENCE | 'inference' | Running inference on inputs |
| POSTPROCESS | 'postprocess' | Post-processing results |

self.set_progress(1, 3, self.progress.MODEL_LOAD)
self.set_progress(2, 3, self.progress.INFERENCE)
self.set_progress(3, 3, self.progress.POSTPROCESS)

Helper Methods

get_model

Retrieve model metadata by ID.

def get_model(self, model_id: int) -> dict[str, Any]

| Parameter | Type | Required | Description |
| --- | --- | --- | --- |
| model_id | int | Yes | Model identifier |

Returns: Model metadata dictionary including file URL.

model = self.get_model(123)
print(model['name'], model['file'])

download_model

Download and extract model artifacts.

def download_model(
    self,
    model_id: int,
    output_dir: str | Path | None = None,
) -> Path

| Parameter | Type | Required | Default | Description |
| --- | --- | --- | --- | --- |
| model_id | int | Yes | - | Model identifier |
| output_dir | str \| Path \| None | No | None | Directory to extract model to. Uses tempdir if None |

Returns: Path to extracted model directory.

model_path = self.download_model(123)
# model_path contains extracted model artifacts

load_model

Load model for inference. Downloads artifacts and returns model info with local path.

def load_model(self, model_id: int) -> dict[str, Any]

| Parameter | Type | Required | Description |
| --- | --- | --- | --- |
| model_id | int | Yes | Model identifier |

Returns: Model metadata dict with 'path' key for local artifacts.

model_info = self.load_model(123)
model_path = model_info['path']
# Load your model framework here:
# model = torch.load(model_path / 'model.pt')

infer

Run inference on inputs. Override this method to implement your inference logic.

def infer(
    self,
    model: Any,
    inputs: list[dict[str, Any]],
) -> list[dict[str, Any]]

| Parameter | Type | Required | Description |
| --- | --- | --- | --- |
| model | Any | Yes | Loaded model (framework-specific) |
| inputs | list[dict[str, Any]] | Yes | List of input dictionaries |

Returns: List of result dictionaries.

def infer(self, model, inputs):
    results = []
    for inp in inputs:
        prediction = model.predict(inp['image'])
        results.append({'prediction': prediction})
    return results

Execution Modes

Simple Mode

Override execute() directly for simple workflows:

from pathlib import Path
from typing import Any

from pydantic import BaseModel

from synapse_sdk.plugins.actions.inference import BaseInferenceAction


class InferenceParams(BaseModel):
    model_id: int
    inputs: list[dict]


class MyInferenceAction(BaseInferenceAction[InferenceParams]):
    action_name = 'inference'

    def execute(self) -> dict[str, Any]:
        # Load model
        model_info = self.load_model(self.params.model_id)
        self.set_progress(1, 3, self.progress.MODEL_LOAD)

        # Run inference
        results = self.infer(model_info, self.params.inputs)
        self.set_progress(2, 3, self.progress.INFERENCE)

        # Post-process
        processed = self._postprocess(results)
        self.set_progress(3, 3, self.progress.POSTPROCESS)

        return {'results': processed}

    def infer(self, model, inputs):
        import torch

        model_obj = torch.load(Path(model['path']) / 'model.pt')
        results = []
        for inp in inputs:
            pred = model_obj(inp['tensor'])
            results.append({'prediction': pred.tolist()})
        return results

Step-Based Mode

Use setup_steps() to register workflow steps for complex pipelines:

from synapse_sdk.plugins.actions.inference import (
    BaseInferenceAction,
    InferenceContext,
)
from synapse_sdk.plugins.steps import BaseStep, StepResult, StepRegistry


class LoadModelStep(BaseStep[InferenceContext]):
    @property
    def name(self) -> str:
        return 'load_model'

    @property
    def progress_weight(self) -> float:
        return 0.3

    def execute(self, context: InferenceContext) -> StepResult:
        # Load model using context
        import torch

        model_path = context.model_path
        context.model = torch.load(f'{model_path}/model.pt')
        return StepResult(success=True)


class InferenceStep(BaseStep[InferenceContext]):
    @property
    def name(self) -> str:
        return 'inference'

    @property
    def progress_weight(self) -> float:
        return 0.7

    def execute(self, context: InferenceContext) -> StepResult:
        for request in context.requests:
            prediction = context.model(request['input'])
            context.results.append({'prediction': prediction})
            context.processed_count += 1
        return StepResult(success=True)


class MyInferenceAction(BaseInferenceAction[InferenceParams]):
    action_name = 'inference'

    def setup_steps(self, registry: StepRegistry[InferenceContext]) -> None:
        registry.register(LoadModelStep())
        registry.register(InferenceStep())

InferenceContext

Context for inference action step-based workflows. Extends BaseStepContext with inference-specific state.

@dataclass
class InferenceContext(BaseStepContext):
    params: dict[str, Any] = field(default_factory=dict)
    model_id: int | None = None
    model: dict[str, Any] | None = None
    model_path: str | None = None
    requests: list[dict[str, Any]] = field(default_factory=list)
    results: list[dict[str, Any]] = field(default_factory=list)
    batch_size: int = 1
    processed_count: int = 0

| Attribute | Type | Description |
| --- | --- | --- |
| params | dict[str, Any] | Action parameters |
| model_id | int \| None | ID of the model being used |
| model | dict[str, Any] \| None | Loaded model information from backend |
| model_path | str \| None | Local path to downloaded model |
| requests | list[dict[str, Any]] | Input requests to process |
| results | list[dict[str, Any]] | Inference results |
| batch_size | int | Batch size for processing |
| processed_count | int | Number of processed items |
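
A minimal sketch of exercising the InferenceStep from the Step-Based Mode example above against a hand-built context. This assumes BaseStepContext needs no additional constructor arguments, and uses a plain callable as a stand-in for a loaded model, matching how the example step calls context.model:

ctx = InferenceContext(
    model_id=123,
    model=lambda x: sum(x),  # stand-in for a loaded model callable
    requests=[{'input': [1, 2, 3]}, {'input': [4, 5, 6]}],
)

result = InferenceStep().execute(ctx)
assert result.success
print(ctx.results)          # [{'prediction': 6}, {'prediction': 15}]
print(ctx.processed_count)  # 2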

Example: Batch Inference

Complete example of a batch inference action with PyTorch:

from pathlib import Path
from typing import Any

import torch
from pydantic import BaseModel

from synapse_sdk.plugins.actions.inference import BaseInferenceAction


class BatchInferenceParams(BaseModel):
    model_id: int
    inputs: list[dict[str, Any]]
    batch_size: int = 32


class BatchInferenceAction(BaseInferenceAction[BatchInferenceParams]):
    """Batch inference action for PyTorch models."""

    action_name = 'batch_inference'

    def execute(self) -> dict[str, Any]:
        # Step 1: Load model
        model_info = self.load_model(self.params.model_id)
        model_path = Path(model_info['path'])
        model = torch.load(model_path / 'model.pt')
        model.eval()
        self.set_progress(1, 3, self.progress.MODEL_LOAD)

        # Step 2: Run inference in batches
        results = []
        inputs = self.params.inputs
        batch_size = self.params.batch_size

        for i in range(0, len(inputs), batch_size):
            batch = inputs[i:i + batch_size]
            batch_tensors = torch.stack([
                torch.tensor(inp['data']) for inp in batch
            ])

            with torch.no_grad():
                predictions = model(batch_tensors)

            for pred in predictions:
                results.append({'prediction': pred.tolist()})

        self.set_progress(2, 3, self.progress.INFERENCE)

        # Step 3: Post-process
        self.set_progress(3, 3, self.progress.POSTPROCESS)

        return {
            'results': results,
            'processed_count': len(results),
        }

    def infer(self, model, inputs):
        # Optional: Override for custom inference logic
        pass

Deployment Actions

Deployment actions serve REST APIs via Ray Serve; use BaseDeploymentAction to deploy inference endpoints.

BaseDeploymentAction

Base class for Ray Serve deployment actions. Handles Ray initialization, deployment creation, and backend registration.

class BaseDeploymentAction(BaseAction[P]):
    progress = DeploymentProgressCategories()
    entrypoint: type | None = None  # Set to your serve deployment class

DeploymentProgressCategories

| Category | Value | Description |
| --- | --- | --- |
| INITIALIZE | 'initialize' | Ray cluster initialization |
| DEPLOY | 'deploy' | Deploying to Ray Serve |
| REGISTER | 'register' | Registering with backend |
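
Usage mirrors the inference progress example above; the deployment example later on this page reports progress the same way:

self.set_progress(1, 3, self.progress.INITIALIZE)
self.set_progress(2, 3, self.progress.DEPLOY)
self.set_progress(3, 3, self.progress.REGISTER)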

Deployment Methods

ray_init

Initialize Ray cluster connection.

def ray_init(self, **kwargs: Any) -> None

deploy

Deploy the inference endpoint to Ray Serve.

def deploy(self) -> None

register_serve_application

Register the serve application with the backend.

def register_serve_application(self) -> int | None

Returns: Serve application ID if created, None otherwise.

Configuration Methods

Override these methods to customize deployment:

| Method | Default | Description |
| --- | --- | --- |
| get_serve_app_name() | SYNAPSE_PLUGIN_RELEASE_CODE env var | Serve application name |
| get_route_prefix() | SYNAPSE_PLUGIN_RELEASE_CHECKSUM env var | URL route prefix |
| get_ray_actor_options() | Extract from params | Ray actor options (num_cpus, num_gpus, memory) |
| get_runtime_env() | {} | Ray runtime environment |
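
A minimal sketch of overriding these hooks on a deployment action. The class name and return values are illustrative, the return type annotations are assumptions based on the defaults above, and DeploymentParams stands for your own pydantic params model (one is defined in the deployment example later on):

from typing import Any

from synapse_sdk.plugins.actions.inference import BaseDeploymentAction


class ConfiguredDeploymentAction(BaseDeploymentAction[DeploymentParams]):
    action_name = 'deployment'

    def get_serve_app_name(self) -> str:
        # Fixed name instead of the SYNAPSE_PLUGIN_RELEASE_CODE env var
        return 'pytorch-classifier'

    def get_route_prefix(self) -> str:
        # Fixed route instead of the SYNAPSE_PLUGIN_RELEASE_CHECKSUM env var
        return '/pytorch-classifier'

    def get_ray_actor_options(self) -> dict[str, Any]:
        # Pin resources explicitly rather than extracting them from params
        return {'num_cpus': 2, 'num_gpus': 1}

    def get_runtime_env(self) -> dict[str, Any]:
        # Extra pip dependencies for the Ray runtime environment
        return {'pip': ['torch']}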

BaseServeDeployment

Base class for Ray Serve inference deployments. Inherits from BaseAction and provides model loading with multiplexing support.

class BaseServeDeployment(BaseAction):
    def __init__(self, backend_url: str) -> None:
        self.backend_url = backend_url
        self._model_cache: dict[str, Any] = {}

Abstract methods to implement:

| Method | Description |
| --- | --- |
| async _get_model(model_info: dict) -> Any | Load model from extracted artifacts |
| async infer(*args, **kwargs) -> Any | Run inference on inputs |

Class attributes:

| Attribute | Description |
| --- | --- |
| action_name | Name for action discovery (e.g., 'inference') |
| app | FastAPI app instance (decorators applied automatically by deploy()) |

infer_remote

Call the deployed serve endpoint for inference. Used by the entrypoint when an inference action is executed.

@classmethod
def infer_remote(cls, params: dict[str, Any], ctx: Any) -> Any

Params format:

| Key | Type | Required | Description |
| --- | --- | --- | --- |
| model | int \| str | No | Model ID for multiplexing |
| method | str | No | HTTP method (default: 'post') |
| json | dict | Yes | Request body sent to the serve endpoint |

# Inference params payload
params = {
    "model": 34,
    "method": "post",
    "json": {
        "image_path": "https://example.com/image.jpg",
        "threshold": 0.5,
    }
}

DeploymentContext

Context for deployment action step-based workflows.

@dataclass
class DeploymentContext(BaseStepContext):
    params: dict[str, Any] = field(default_factory=dict)
    model_id: int | None = None
    model: dict[str, Any] | None = None
    model_path: str | None = None
    serve_app_name: str | None = None
    serve_app_id: int | None = None
    route_prefix: str | None = None
    ray_actor_options: dict[str, Any] = field(default_factory=dict)
    deployed: bool = False

| Attribute | Type | Description |
| --- | --- | --- |
| serve_app_name | str \| None | Ray Serve application name |
| serve_app_id | int \| None | ID of created serve application |
| route_prefix | str \| None | URL route prefix for deployment |
| ray_actor_options | dict[str, Any] | Ray actor configuration |
| deployed | bool | Whether deployment succeeded |
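
As with InferenceContext, this context carries state between steps in a step-based deployment. A minimal sketch of a step that records its outcome on the context; the step class, its values, and the import path of DeploymentContext are illustrative assumptions, and actual Ray Serve calls would go through BaseDeploymentAction.deploy():

from synapse_sdk.plugins.steps import BaseStep, StepResult

# Import path assumed; adjust to where your SDK version exposes DeploymentContext
from synapse_sdk.plugins.actions.inference import DeploymentContext


class RegisterDeploymentStep(BaseStep[DeploymentContext]):
    @property
    def name(self) -> str:
        return 'register'

    @property
    def progress_weight(self) -> float:
        return 0.2

    def execute(self, context: DeploymentContext) -> StepResult:
        # Record the deployment settings and outcome on the shared context
        context.serve_app_name = context.params.get('serve_app_name', 'my-app')
        context.route_prefix = '/my-app'
        context.ray_actor_options = {'num_gpus': 1}
        context.deployed = True
        return StepResult(success=True)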

create_serve_multiplexed_model_id

Create a JWT-encoded model ID for serve multiplexing.

def create_serve_multiplexed_model_id(
    model_id: int | str,
    token: str,
    backend_url: str,
    tenant: str | None = None,
) -> str

| Parameter | Type | Required | Description |
| --- | --- | --- | --- |
| model_id | int \| str | Yes | Model ID to encode |
| token | str | Yes | User access token |
| backend_url | str | Yes | Backend URL (used as JWT secret) |
| tenant | str \| None | No | Tenant identifier |

Returns: JWT-encoded model token string.

from synapse_sdk.plugins.actions.inference import create_serve_multiplexed_model_id

model_token = create_serve_multiplexed_model_id(
    model_id=123,
    token='user_access_token',
    backend_url='https://api.example.com',
    tenant='my-tenant',
)
# Use in request headers:
headers = {'serve_multiplexed_model_id': model_token}

Example: Model Deployment

Complete example of deploying a PyTorch model with Ray Serve:

from typing import Any

from fastapi import FastAPI
from pydantic import BaseModel

from synapse_sdk.plugins.actions.inference import (
    BaseDeploymentAction,
    BaseServeDeployment,
)

app = FastAPI()


class PyTorchInference(BaseServeDeployment):
    """PyTorch inference deployment.

    The @serve.deployment and @serve.ingress(app) decorators are applied
    automatically by BaseDeploymentAction.deploy().
    """

    action_name = 'inference'
    app = app

    async def _get_model(self, model_info: dict[str, Any]) -> Any:
        import torch

        model_path = model_info['path'] / 'model.pt'
        model = torch.load(model_path)
        model.eval()
        return model

    @app.post('/')
    async def infer(self, inputs: list[dict]) -> list[dict]:
        model = await self.get_model()

        import torch

        results = []
        for inp in inputs:
            tensor = torch.tensor(inp['data'])
            with torch.no_grad():
                prediction = model(tensor)
            results.append({'prediction': prediction.tolist()})

        return results


class DeploymentParams(BaseModel):
    model: int
    num_gpus: int = 1


class MyDeploymentAction(BaseDeploymentAction[DeploymentParams]):
    """Deploy PyTorch model to Ray Serve."""

    action_name = 'deployment'
    entrypoint = PyTorchInference

    def execute(self) -> dict[str, Any]:
        # Initialize Ray
        self.ray_init()
        self.set_progress(1, 3, self.progress.INITIALIZE)

        # Deploy to Ray Serve
        self.deploy()
        self.set_progress(2, 3, self.progress.DEPLOY)

        # Register with backend
        app_id = self.register_serve_application()
        self.set_progress(3, 3, self.progress.REGISTER)

        return {'serve_application': app_id}

Running Inference

Once deployed, call the serve endpoint via the inference action:

synapse plugin run inference --mode job --params '{"model": 34, "method": "post", "json": {"inputs": [{"data": [1, 2, 3]}]}}'

The entrypoint detects BaseServeDeployment subclasses and calls infer_remote(), which resolves the deployed endpoint's route prefix and forwards the request.
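
You can also call the endpoint over HTTP yourself once you know its route prefix. A minimal sketch, where the host, port, and route prefix are assumptions (Ray Serve listens on port 8000 by default, and the route prefix comes from get_route_prefix()); the multiplexing header is only needed when one deployment serves multiple models:

import requests

from synapse_sdk.plugins.actions.inference import create_serve_multiplexed_model_id

# Encode the model ID for serve multiplexing (see create_serve_multiplexed_model_id above)
model_token = create_serve_multiplexed_model_id(
    model_id=34,
    token='user_access_token',
    backend_url='https://api.example.com',
)

response = requests.post(
    'http://localhost:8000/<route-prefix>/',
    headers={'serve_multiplexed_model_id': model_token},
    json={'inputs': [{'data': [1, 2, 3]}]},
)
print(response.json())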


Best Practices

Model Caching

Cache loaded models to avoid repeated downloads:

from pathlib import Path
from typing import Any

import torch


class CachedInferenceAction(BaseInferenceAction[InferenceParams]):
    _model_cache: dict[int, Any] = {}

    def load_model_cached(self, model_id: int) -> Any:
        if model_id not in self._model_cache:
            model_info = self.load_model(model_id)
            self._model_cache[model_id] = torch.load(Path(model_info['path']) / 'model.pt')
        return self._model_cache[model_id]

Batch Processing Optimization

Process inputs in batches for better throughput:

def execute(self) -> dict[str, Any]:
    model = self.load_model_cached(self.params.model_id)

    # Process in batches
    batch_size = self.params.batch_size
    results = []

    for i in range(0, len(self.params.inputs), batch_size):
        batch = self.params.inputs[i:i + batch_size]
        batch_results = self._process_batch(model, batch)
        results.extend(batch_results)

        # Update progress
        progress = min((i + batch_size) / len(self.params.inputs), 1.0)
        self.set_progress(int(progress * 100), 100, self.progress.INFERENCE)

    return {'results': results}

Error Handling

Handle model loading and inference errors gracefully:

def execute(self) -> dict[str, Any]:
    try:
        model_info = self.load_model(self.params.model_id)
    except ValueError as e:
        self.log('model_load_error', {'error': str(e)})
        return {'error': f'Failed to load model: {e}'}

    try:
        results = self.infer(model_info, self.params.inputs)
    except Exception as e:
        self.log('inference_error', {'error': str(e)})
        return {'error': f'Inference failed: {e}', 'partial_results': []}

    return {'results': results}