Work with ML training jobs with Viam's ML training API
The ML training API lets you submit ML training jobs on Viam, get information about them, cancel them, and delete completed jobs.
The ML training client API supports the following methods:
| Method Name | Description |
|---|---|
| SubmitTrainingJob | Submit a training job. |
| SubmitCustomTrainingJob | Submit a training job from a custom training script. |
| GetTrainingJob | Get training job metadata. |
| ListTrainingJobs | Get training job metadata for all jobs within an organization. |
| CancelTrainingJob | Cancel the specified training job. |
| DeleteCompletedTrainingJob | Delete a completed training job from the database, whether the job succeeded or failed. |
Establish a connection
To use the ML training client API, you need to instantiate a ViamClient and then instantiate an MLTrainingClient.
You need an API key and API key ID with Org owner permissions to use the ML training client API. To get an API key (and corresponding ID), use the web UI or the Viam CLI.
import asyncio

from viam.rpc.dial import DialOptions, Credentials
from viam.app.viam_client import ViamClient


async def connect() -> ViamClient:
    dial_options = DialOptions(
        credentials=Credentials(
            type="api-key",
            # TODO: Replace "<API-KEY>" (including brackets) with your machine's
            # API key
            payload='<API-KEY>',
        ),
        # TODO: Replace "<API-KEY-ID>" (including brackets) with your machine's
        # API key ID
        auth_entity='<API-KEY-ID>'
    )
    return await ViamClient.create_from_dial_options(dial_options)


async def main():
    # Make a ViamClient
    async with await connect() as viam_client:
        # Instantiate an MLTrainingClient to run ML training client API methods on
        ml_training_client = viam_client.ml_training_client


if __name__ == '__main__':
    asyncio.run(main())
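The Python connection example embeds the API key and API key ID as placeholder strings. One alternative is to read them from the environment so that they stay out of source control. The sketch below uses only the standard library; the variable names `VIAM_API_KEY` and `VIAM_API_KEY_ID` and the helper `load_api_credentials` are hypothetical names chosen for illustration, not part of the Viam SDK.

```python
import os
from typing import Mapping, Tuple


def load_api_credentials(env: Mapping[str, str] = os.environ) -> Tuple[str, str]:
    """Return (api_key, api_key_id) read from environment variables.

    VIAM_API_KEY and VIAM_API_KEY_ID are hypothetical names picked for this
    sketch; use whatever names fit your deployment.
    """
    required = ("VIAM_API_KEY", "VIAM_API_KEY_ID")
    # Treat unset and empty variables the same way: both are unusable.
    missing = [name for name in required if not env.get(name)]
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))
    return env["VIAM_API_KEY"], env["VIAM_API_KEY_ID"]
```

In `connect()`, the returned values would take the place of the placeholders: the first as the `payload` of `Credentials`, the second as `auth_entity`.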
package main

import (
	"context"

	"go.viam.com/rdk/app"
	"go.viam.com/rdk/logging"
)

func main() {
	logger := logging.NewDebugLogger("client")
	ctx := context.Background()
	// TODO: Replace "<API-KEY>" (including brackets) with your machine's API key
	// TODO: Replace "<API-KEY-ID>" (including brackets) with your machine's
	// API key ID
	viamClient, err := app.CreateViamClientWithAPIKey(
		ctx, app.Options{}, "<API-KEY>", "<API-KEY-ID>", logger)
	if err != nil {
		logger.Fatal(err)
	}
	defer viamClient.Close()

	mlTrainingClient := viamClient.MLTrainingClient()
	// Call ML training API methods on mlTrainingClient here. The blank
	// assignment keeps this skeleton compiling until you do.
	_ = mlTrainingClient
}
import * as VIAM from "@viamrobotics/sdk";

async function connect(): Promise<VIAM.ViamClient> {
  // TODO: Replace "<API-KEY-ID>" (including brackets) with your machine's
  // API key ID
  const API_KEY_ID = "<API-KEY-ID>";
  // TODO: Replace "<API-KEY>" (including brackets) with your machine's API key
  const API_KEY = "<API-KEY>";
  const opts: VIAM.ViamClientOptions = {
    serviceHost: "https://app.viam.com:443",
    credentials: {
      type: "api-key",
      authEntity: API_KEY_ID,
      payload: API_KEY,
    },
  };

  const client = await VIAM.createViamClient(opts);
  return client;
}

const viamClient = await connect();
const mlTrainingClient = viamClient.mlTrainingClient;
Once you have instantiated an MLTrainingClient, you can run the following API methods against the MLTrainingClient object (named ml_training_client in the Python examples and mlTrainingClient in the Go and TypeScript examples).
API
SubmitTrainingJob
Submit a training job.
Parameters:
- org_id (str) (required): The ID of the organization to submit the training job to. To retrieve this, expand your organization’s dropdown in the top right corner on Viam, select Settings, and copy Organization ID.
- dataset_id (str) (required): The ID of the dataset to train the ML model on. To retrieve this, navigate to your dataset’s page, click … in the left-hand menu, and click Copy dataset ID.
- model_name (str) (required): The model name.
- model_version (str) (required): The version of the ML model you’re training. This string must differ from any previous versions you’ve set.
- model_type (viam.proto.app.mltraining.ModelType.ValueType) (required): The type of the ML model. Options: ModelType.MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION, ModelType.MODEL_TYPE_MULTI_LABEL_CLASSIFICATION, ModelType.MODEL_TYPE_OBJECT_DETECTION.
- tags (List[str]) (required): The labels to train the model on.
Returns:
- (str): The ID of the training job.
Example:
from viam.proto.app.mltraining import ModelType

job_id = await ml_training_client.submit_training_job(
    org_id="<organization-id>",
    dataset_id="<dataset-id>",
    model_name="<your-model-name>",
    model_version="1",
    model_type=ModelType.MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION,
    tags=["tag1", "tag2"]
)
For more information, see the Python SDK Docs.
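Because `model_version` must not repeat a version you have already used, one possible convention is to derive it from the submission time. The helper below is a minimal sketch, not something the SDK provides: `timestamp_version` is a hypothetical name, and the timestamp layout is only an assumption about what makes a convenient version string.

```python
from datetime import datetime, timezone
from typing import Optional


def timestamp_version(now: Optional[datetime] = None) -> str:
    """Build a version string such as '2024-08-13T12-11-54' from a UTC time.

    Pass a timezone-aware datetime, or omit it to use the current time.
    """
    if now is None:
        now = datetime.now(timezone.utc)
    # Normalize to UTC so that versions sort consistently across machines.
    return now.astimezone(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
```

The result could then be passed as `model_version=timestamp_version()`. Two submissions within the same second would still collide, so this reduces the chance of a duplicate version rather than ruling it out.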
Parameters:
- ctx (Context): A Context carries a deadline, a cancellation signal, and other values across API boundaries.
- args (SubmitTrainingJobArgs)
- modelType (ModelType)
- tags ([]string)
Returns:
For more information, see the Go SDK Docs.
Parameters:
- organizationId (string) (required): The organization ID.
- datasetId (string) (required): The dataset ID.
- modelName (string) (required): The model name.
- modelVersion (string) (required): The model version.
- modelType (ModelType) (required): The model type.
- tags (string[]) (required): The tags.
Returns:
- (Promise)
Example:
await mlTrainingClient.submitTrainingJob(
  '<organization-id>',
  '<dataset-id>',
  '<your-model-name>',
  '1.0.0',
  ModelType.SINGLE_LABEL_CLASSIFICATION,
  ['tag1', 'tag2']
);
For more information, see the TypeScript SDK Docs.
SubmitCustomTrainingJob
Submit a training job from a custom training script. Follow the guide to Train a Model with a Custom Python Training Script.
Parameters:
- org_id (str) (required): The ID of the org to submit the training job to.
- dataset_id (str) (required): The ID of the dataset to train the model on.
- registry_item_id (str) (required): The ID of the training script from the registry.
- registry_item_version (str) (required): The version of the training script from the registry.
- model_name (str) (required): The model name.
- model_version (str) (required): The model version.
Returns:
- (str): The ID of the training job.
Example:
job_id = await ml_training_client.submit_custom_training_job(
    org_id="<organization-id>",
    dataset_id="<dataset-id>",
    registry_item_id="viam:classification-tflite",
    registry_item_version="2024-08-13T12-11-54",
    model_name="<your-model-name>",
    model_version="1"
)
For more information, see the Python SDK Docs.
Parameters:
- ctx (Context): A Context carries a deadline, a cancellation signal, and other values across API boundaries.
- args (SubmitTrainingJobArgs)
- registryItemID, registryItemVersion (string)
- arguments (map[string]string)
Returns:
For more information, see the Go SDK Docs.
Parameters:
- organizationId (string) (required): The organization ID.
- datasetId (string) (required): The dataset ID.
- registryItemId (string) (required): The registry item ID.
- registryItemVersion (string) (required): The registry item version.
- modelName (string) (required): The model name.
- modelVersion (string) (required): The model version.
Returns:
- (Promise)
Example:
await mlTrainingClient.submitCustomTrainingJob(
  '<organization-id>',
  '<dataset-id>',
  'viam:classification-tflite',
  '1.0.0',
  '<your-model-name>',
  '1.0.0'
);
For more information, see the TypeScript SDK Docs.
GetTrainingJob
Get training job metadata.
Parameters:
- id (str) (required): The ID of the requested training job.
Returns:
- (viam.proto.app.mltraining.TrainingJobMetadata): The training job data.
Example:
job_metadata = await ml_training_client.get_training_job(
    id="<job-id>")
For more information, see the Python SDK Docs.
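GetTrainingJob returns the metadata as it stands at the moment of the call, so waiting for a job to finish means calling it repeatedly. The sketch below shows one way to poll, under two assumptions: the returned metadata exposes a `status` attribute, and the caller supplies the set of statuses that count as finished. `wait_for_training_job` and its parameters are hypothetical names, not SDK functions; the only SDK call it makes is `get_training_job(id=...)` as documented above.

```python
import asyncio
from typing import Any, Collection


async def wait_for_training_job(
    ml_training_client: Any,
    job_id: str,
    terminal_statuses: Collection[Any],
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
) -> Any:
    """Poll get_training_job until the job reaches one of terminal_statuses.

    Assumes the returned metadata has a `status` attribute. Raises
    TimeoutError if no terminal status is seen before the deadline.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_seconds
    while True:
        metadata = await ml_training_client.get_training_job(id=job_id)
        if metadata.status in terminal_statuses:
            return metadata
        if loop.time() >= deadline:
            raise TimeoutError(f"Training job {job_id} did not finish in time")
        # Sleep between calls instead of polling in a tight loop.
        await asyncio.sleep(poll_seconds)
```

For `terminal_statuses` you would pass the values of `viam.proto.app.mltraining.TrainingStatus` that mean the job is no longer running; which values those are should be checked against the SDK rather than taken from this sketch.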
Parameters:
- ctx (Context): A Context carries a deadline, a cancellation signal, and other values across API boundaries.
- id (string)
Returns:
- (*TrainingJobMetadata)
- (error): An error, if one occurred.
For more information, see the Go SDK Docs.
Parameters:
- id (string) (required): The training job ID.
Returns:
- (Promise<undefined | TrainingJobMetadata>)
Example:
const job = await mlTrainingClient.getTrainingJob('<training-job-id>');
For more information, see the TypeScript SDK Docs.
ListTrainingJobs
Get training job metadata for all jobs within an organization.
Parameters:
- org_id (str) (required): The ID of the org to request training job data from.
- training_status (viam.proto.app.mltraining.TrainingStatus.ValueType) (optional): The status to filter the training jobs list by. If unspecified, all training jobs will be returned.
Returns:
- (List[viam.proto.app.mltraining.TrainingJobMetadata]): The list of training job data.
Example:
jobs_metadata = await ml_training_client.list_training_jobs(
    org_id="<org-id>")

# Lists are zero-indexed, so the first job is at index 0.
first_job_id = jobs_metadata[0].id
For more information, see the Python SDK Docs.
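When an organization has many jobs, it can help to bucket the returned metadata by status before acting on it. The following is a small sketch that assumes each item has `id` and `status` attributes; `job_ids_by_status` is a hypothetical helper, and it works on the list you already fetched rather than calling the API itself.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List


def job_ids_by_status(jobs: Iterable[Any]) -> Dict[Any, List[str]]:
    """Group training job IDs by status, keeping the original order.

    Assumes every job object exposes `id` and `status` attributes.
    """
    grouped: Dict[Any, List[str]] = defaultdict(list)
    for job in jobs:
        grouped[job.status].append(job.id)
    # Return a plain dict so that missing statuses raise KeyError.
    return dict(grouped)
```

With the example above, `job_ids_by_status(jobs_metadata)` would return a dictionary keyed by status value. If you only need one status, passing `training_status` to `list_training_jobs` filters on the server side instead.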
Parameters:
- ctx (Context): A Context carries a deadline, a cancellation signal, and other values across API boundaries.
- organizationID (string)
- status (TrainingStatus)
Returns:
- ([]*TrainingJobMetadata)
- (error): An error, if one occurred.
For more information, see the Go SDK Docs.
Parameters:
- organizationId (string) (required): The organization ID.
- status (TrainingStatus) (required): The training job status.
Returns:
- (Promise<TrainingJobMetadata[]>)
Example:
const jobs = await mlTrainingClient.listTrainingJobs(
  '<organization-id>',
  TrainingStatus.RUNNING
);
For more information, see the TypeScript SDK Docs.
CancelTrainingJob
Cancel the specified training job.
Parameters:
- id (str) (required): The ID of the training job you wish to cancel. Retrieve this value with ListTrainingJobs().
Returns:
- None.
Raises:
- (GRPCError): if no training job exists with the given ID.
Example:
await ml_training_client.cancel_training_job(
    id="<job-id>")
For more information, see the Python SDK Docs.
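Since cancelling raises an error when no job exists with the given ID, cancelling several jobs in a row needs some error handling if one failure should not stop the rest. The sketch below is one way to do that; `cancel_training_jobs` is a hypothetical helper, and it catches `Exception` broadly (an assumption made to keep the block free of extra imports) where real code would more likely catch the specific `GRPCError`.

```python
import asyncio
from typing import Any, Dict, Iterable, Optional


async def cancel_training_jobs(
    ml_training_client: Any, job_ids: Iterable[str]
) -> Dict[str, Optional[Exception]]:
    """Try to cancel each job; map job ID to None on success or to the error."""
    results: Dict[str, Optional[Exception]] = {}
    for job_id in job_ids:
        try:
            await ml_training_client.cancel_training_job(id=job_id)
            results[job_id] = None
        except Exception as err:  # broad on purpose; see the note above
            results[job_id] = err
    return results
```

Callers can then report the IDs whose value is not `None` instead of losing them to the first exception.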
Parameters:
- ctx (Context): A Context carries a deadline, a cancellation signal, and other values across API boundaries.
- id (string)
Returns:
- (error): An error, if one occurred.
For more information, see the Go SDK Docs.
Parameters:
- id (string) (required): The training job ID.
Returns:
- (Promise)
Example:
await mlTrainingClient.cancelTrainingJob('<training-job-id>');
For more information, see the TypeScript SDK Docs.
DeleteCompletedTrainingJob
Delete a completed training job from the database, whether the job succeeded or failed.
Parameters:
- id (str) (required): The ID of the training job to delete.
Returns:
- None.
Example:
await ml_training_client.delete_completed_training_job(
    id="<job-id>")
For more information, see the Python SDK Docs.
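Deleting takes one job ID at a time, so clearing out old jobs usually means combining it with ListTrainingJobs. The sketch below lists the jobs that have a given status and deletes each one, using only the `list_training_jobs` and `delete_completed_training_job` calls described on this page. `delete_jobs_with_status` is a hypothetical helper, and it assumes each returned job has an `id` attribute.

```python
import asyncio
from typing import Any, List


async def delete_jobs_with_status(
    ml_training_client: Any, org_id: str, training_status: Any
) -> List[str]:
    """Delete every job in the org that has the given status; return their IDs."""
    jobs = await ml_training_client.list_training_jobs(
        org_id=org_id, training_status=training_status
    )
    deleted: List[str] = []
    for job in jobs:
        await ml_training_client.delete_completed_training_job(id=job.id)
        deleted.append(job.id)
    return deleted
```

Only pass a status that describes finished jobs, since the method is meant for jobs that have already succeeded or failed. Deletion cannot be undone from this sketch, so logging the returned IDs is a reasonable precaution.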
Parameters:
- ctx (Context): A Context carries a deadline, a cancellation signal, and other values across API boundaries.
- id (string)
Returns:
- (error): An error, if one occurred.
For more information, see the Go SDK Docs.
Parameters:
- id (string) (required): The training job ID.
Returns:
- (Promise)
Example:
await mlTrainingClient.deleteCompletedTrainingJob('<training-job-id>');
For more information, see the TypeScript SDK Docs.