Skip to content

Commit e06d3cf

Browse files
authored
feat(batch predition): Adds get_batch_prediction_job.py code sample (#13512)
* Adds get_batch_prediction_job.py code sample * Adds test_batch prediction_job. * Fixes lint and formatting errors. * Fixes errors in test file * Updates method definition. * Fix lint errors
1 parent ddb30bd commit e06d3cf

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.genai import types
16+
17+
18+
def get_batch_job(batch_job_name: str) -> types.BatchJob:
19+
# [START googlegenaisdk_batch_job_get]
20+
from google import genai
21+
from google.genai.types import HttpOptions
22+
23+
client = genai.Client(http_options=HttpOptions(api_version="v1"))
24+
25+
# Get the batch job
26+
# Eg. batch_job_name = "projects/123456789012/locations/us-central1/batchPredictionJobs/1234567890123456789"
27+
batch_job = client.batches.get(name=batch_job_name)
28+
29+
print(f"Job state: {batch_job.state}")
30+
# Example response:
31+
# Job state: JOB_STATE_PENDING
32+
# Job state: JOB_STATE_RUNNING
33+
# Job state: JOB_STATE_SUCCEEDED
34+
35+
# [END googlegenaisdk_batch_job_get]
36+
return batch_job
37+
38+
39+
if __name__ == "__main__":
40+
try:
41+
get_batch_job(input("Batch job name: "))
42+
except Exception as e:
43+
print(f"An error occurred: {e}")

genai/batch_prediction/test_batch_prediction_examples.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,20 @@
1515
#
1616
# Using Google Cloud Vertex AI to test the code samples.
1717
#
18-
1918
from datetime import datetime as dt
20-
2119
import os
2220

21+
from unittest.mock import MagicMock, patch
22+
2323
from google.cloud import bigquery, storage
24+
from google.genai import types
2425
from google.genai.types import JobState
25-
2626
import pytest
2727

2828
import batchpredict_embeddings_with_gcs
2929
import batchpredict_with_bq
3030
import batchpredict_with_gcs
31+
import get_batch_job
3132

3233

3334
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True"
@@ -77,3 +78,20 @@ def test_batch_prediction_with_bq(bq_output_uri: str) -> None:
7778
def test_batch_prediction_with_gcs(gcs_output_uri: str) -> None:
7879
response = batchpredict_with_gcs.generate_content(output_uri=gcs_output_uri)
7980
assert response == JobState.JOB_STATE_SUCCEEDED
81+
82+
83+
@patch("google.genai.Client")
84+
def test_get_batch_job(mock_genai_client: MagicMock) -> None:
85+
# Mock the API response
86+
mock_batch_job = types.BatchJob(
87+
name="test-batch-job",
88+
state="JOB_STATE_PENDING"
89+
)
90+
91+
mock_genai_client.return_value.batches.get.return_value = mock_batch_job
92+
93+
response = get_batch_job.get_batch_job("test-batch-job")
94+
95+
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
96+
mock_genai_client.return_value.batches.get.assert_called_once()
97+
assert response == mock_batch_job

0 commit comments

Comments
 (0)