diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index bcc4b9390..6a7a6475c 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -386,9 +386,6 @@ def _init_model_source(data): gcs_tflite_uri = data.pop('gcsTfliteUri', None) if gcs_tflite_uri: return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) - auto_ml_model = data.pop('automlModel', None) - if auto_ml_model: - return TFLiteAutoMlSource(auto_ml_model=auto_ml_model) return None @property @@ -603,36 +600,6 @@ def as_dict(self, for_upload=False): return {'gcsTfliteUri': self._gcs_tflite_uri} -class TFLiteAutoMlSource(TFLiteModelSource): - """TFLite model source representing a tflite model created with AutoML.""" - - def __init__(self, auto_ml_model, app=None): - self._app = app - self.auto_ml_model = auto_ml_model - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.auto_ml_model == other.auto_ml_model - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @property - def auto_ml_model(self): - """Resource name of the model, created by the AutoML API or Cloud console.""" - return self._auto_ml_model - - @auto_ml_model.setter - def auto_ml_model(self, auto_ml_model): - self._auto_ml_model = _validate_auto_ml_model(auto_ml_model) - - def as_dict(self, for_upload=False): - """Returns a serializable representation of the object.""" - # Upload is irrelevant for auto_ml models - return {'automlModel': self._auto_ml_model} - - class ListModelsPage: """Represents a page of models in a Firebase project. diff --git a/integration/test_ml.py b/integration/test_ml.py index 52cb1bb7e..5462924bb 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -22,7 +22,6 @@ import pytest -import firebase_admin from firebase_admin import exceptions from firebase_admin import ml from tests import testutils @@ -35,12 +34,6 @@ except ImportError: _TF_ENABLED = False -try: - from google.cloud import automl_v1 - _AUTOML_ENABLED = True -except ImportError: - _AUTOML_ENABLED = False - def _random_identifier(prefix): #pylint: disable=unused-variable suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)]) @@ -159,14 +152,6 @@ def check_tflite_gcs_format(model, validation_error=None): assert model.model_hash is not None -def check_tflite_automl_format(model): - assert model.validation_error is None - assert model.published is False - assert model.model_format.model_source.auto_ml_model.startswith('projects/') - # Automl models don't have validation errors since they are references - # to valid automl models. - - @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_create_simple_model(firebase_model): check_model(firebase_model, NAME_AND_TAGS_ARGS) @@ -388,50 +373,3 @@ def test_from_saved_model(saved_model_dir): assert created_model.validation_error is None finally: _clean_up_model(created_model) - - -# Test AutoML functionality if AutoML is enabled. -#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True -# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the -# successful test. (Test is skipped otherwise) - -@pytest.fixture -def automl_model(): - assert _AUTOML_ENABLED - - # It takes > 20 minutes to train a model, so we expect a predefined AutoMl - # model named 'admin_sdk_integ_test1' to exist in the project, or we skip - # the test. - automl_client = automl_v1.AutoMlClient() - project_id = firebase_admin.get_app().project_id - parent = automl_client.location_path(project_id, 'us-central1') - models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1") - # Expecting exactly one. (Ok to use last one if somehow more than 1) - automl_ref = None - for model in models: - automl_ref = model.name - - # Skip if no pre-defined model. (It takes min > 20 minutes to train a model) - if automl_ref is None: - pytest.skip("No pre-existing AutoML model found. Skipping test") - - source = ml.TFLiteAutoMlSource(automl_ref) - tflite_format = ml.TFLiteFormat(model_source=source) - ml_model = ml.Model( - display_name=_random_identifier('TestModel_automl_'), - tags=['test_automl'], - model_format=tflite_format) - model = ml.create_model(model=ml_model) - yield model - _clean_up_model(model) - -@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.') -def test_automl_model(automl_model): - # This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1' - automl_model.wait_for_unlocked() - - check_model(automl_model, { - 'display_name': automl_model.display_name, - 'tags': ['test_automl'], - }) - check_tflite_automl_format(automl_model) diff --git a/tests/test_ml.py b/tests/test_ml.py index abd6d06f9..2c652ffa0 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -122,18 +122,6 @@ } TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) -AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263' -AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) -TFLITE_FORMAT_JSON_3 = { - 'automlModel': AUTOML_MODEL_NAME, - 'sizeBytes': '3456789' -} -TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3) - -AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222' -AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2} -AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2) - CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -417,14 +405,6 @@ def test_model_keyword_based_creation_and_setters(self): 'tfliteModel': TFLITE_FORMAT_JSON_2 } - model.model_format = TFLITE_FORMAT_3 - assert model.as_dict() == { - 'displayName': DISPLAY_NAME_2, - 'tags': TAGS_2, - 'tfliteModel': TFLITE_FORMAT_JSON_3 - } - - def test_gcs_tflite_model_format_source_creation(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) @@ -436,17 +416,6 @@ def test_gcs_tflite_model_format_source_creation(self): } } - def test_auto_ml_tflite_model_format_source_creation(self): - model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME) - model_format = ml.TFLiteFormat(model_source=model_source) - model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) - assert model.as_dict() == { - 'displayName': DISPLAY_NAME_1, - 'tfliteModel': { - 'automlModel': AUTOML_MODEL_NAME - } - } - def test_source_creation_from_tflite_file(self): model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( "my_model.tflite", "my_bucket") @@ -460,13 +429,6 @@ def test_gcs_tflite_model_source_setters(self): assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 - def test_auto_ml_tflite_model_source_setters(self): - model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) - model_source.auto_ml_model = AUTOML_MODEL_NAME_2 - assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2 - assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2 - - def test_model_format_setters(self): model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 @@ -477,14 +439,6 @@ def test_model_format_setters(self): } } - model_format.model_source = AUTOML_MODEL_SOURCE - assert model_format.model_source == AUTOML_MODEL_SOURCE - assert model_format.as_dict() == { - 'tfliteModel': { - 'automlModel': AUTOML_MODEL_NAME - } - } - def test_model_as_dict_for_upload(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) @@ -570,23 +524,6 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type): ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(excinfo, exc_type) - @pytest.mark.parametrize('auto_ml_model, exc_type', [ - (123, TypeError), - ('abc', ValueError), - ('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError), - ('projects/123546/models/ICN123456', ValueError), - ('projects//locations/us-central1/models/ICN123456', ValueError), - ('projects/123456/locations//models/ICN123456', ValueError), - ('projects/123456/locations/us-central1/models/', ValueError), - ('projects/ABC/locations/us-central1/models/ICN123456', ValueError), - ('projects/123456/locations/us-central1/models/@#$%^&', ValueError), - ('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError), - ]) - def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type): - with pytest.raises(exc_type) as excinfo: - ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model) - check_error(excinfo, exc_type) - def test_wait_for_unlocked_not_locked(self): model = ml.Model(display_name="not_locked") model.wait_for_unlocked()
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: