Source code for mantlebio.core.dataset.client


from abc import abstractmethod
from typing import Any, Dict, Iterable, Optional
import warnings
from mantlebio.core.dataset.helpers import DatasetPropertiesBuilder, unmarshall_dataset
from mantlebio.core.dataset.mantle_dataset import _IDataset, MantleDataset
from mantlebio.core.session.mantle_session import _ISession
from mantlebio.core.storage.client import _IStorageClient
from mantlebio.exceptions import MantleInvalidParameterError, MantleMissingParameterError, MantleProtoError
from mantlebio.helpers.decorators import deprecated
from mantlebio.types.query.query_builder import QueryBuilder
from mantlebio.types.query.queryable_client import QueryableClient
from mantlebio.types.response.list_reponse import ListResponse
from proto import entity_pb2
from google.protobuf.message import DecodeError

class _IDatasetClient(QueryableClient):
    """Interface for a dataset client that retrieves, lists, and creates datasets."""

    @abstractmethod
    def __init__(self, session: _ISession, storage_client: _IStorageClient, analysis_id: str = "", pipeline_run_id: str = "") -> None:
        """Bind the client to a session and a storage client.

        Args:
            session (_ISession): Session used to issue requests.
            storage_client (_IStorageClient): Storage client used for file transfers.
            analysis_id (str, optional): Analysis identifier. Defaults to "".
            pipeline_run_id (str, optional): Pipeline run identifier. Defaults to "".
        """
        pass

    @abstractmethod
    def get(self, id: str) -> _IDataset:
        """Retrieve a dataset by its ID.

        Args:
            id (str): The ID of the dataset.

        Returns:
            _IDataset: The retrieved dataset.
        """
        raise NotImplementedError

    @abstractmethod
    def get_entities(self) -> ListResponse[_IDataset]:
        """Retrieve all datasets (legacy "entity" name; prefer ``list``).

        Returns:
            ListResponse[_IDataset]: An iterable list of all datasets.
        """
        raise NotImplementedError

    @abstractmethod
    def get_datasets(self) -> ListResponse[_IDataset]:
        """Retrieve all datasets (prefer ``list``).

        Returns:
            ListResponse[_IDataset]: An iterable list of all datasets.
        """
        raise NotImplementedError

    @abstractmethod
    def list(self, query_params: Optional[Dict[str, str]] = None) -> ListResponse[_IDataset]:
        """Retrieve all datasets, optionally filtered.

        Args:
            query_params (Dict[str, str], optional): Query parameters to filter the
                listing. Defaults to None (no filtering). Implementations must not
                mutate the caller's dict.

        Returns:
            ListResponse[_IDataset]: An iterable list of the matching datasets.
        """
        raise NotImplementedError

    @abstractmethod
    def create(self, name: Optional[str] = None, dataset_type: Optional[str] = None, properties: Optional[Dict[str, Any]] = None, local: bool = True, origin: Optional[entity_pb2.Origin] = None, entity_type: Optional[str] = None) -> _IDataset:
        """Create a new dataset.

        Args:
            name (str, optional): The name of the dataset. Defaults to None.
            dataset_type (str, optional): The type of the dataset. Defaults to None.
            properties (Dict[str, Any], optional): The properties of the dataset. Defaults to None.
            local (bool, optional): Whether the dataset is local. Defaults to True.
            origin (entity_pb2.Origin, optional): The origin of the dataset. Defaults to None.
            entity_type (str, optional): Deprecated alias for ``dataset_type``; will be
                removed in 2.0.0. Defaults to None.

        Returns:
            _IDataset: The created dataset.
        """
        raise NotImplementedError
class DatasetClient(_IDatasetClient):
    """DatasetClient object for making requests to the Mantle API.

    Talks to the ``/entity`` route to fetch a single dataset, page through
    dataset listings, and create local or cloud datasets. Creating a cloud
    dataset also uploads any file-valued properties through the storage client.
    """

    def __init__(self, session: _ISession, storage_client: _IStorageClient, analysis_id: str = "", pipeline_run_id: str = "") -> None:
        self._session = session
        self._storage_client = storage_client
        self._route_stem = "/entity"
        # Stored on the instance; not read by any method of this class.
        self._analysis_id = analysis_id
        self._pipeline_run_id = pipeline_run_id

    def get(self, id: str) -> _IDataset:
        """Get a dataset by ID.

        Args:
            id (str): Dataset ID.

        Returns:
            _IDataset: The dataset, wrapped as a ``MantleDataset``.

        Raises:
            MantleProtoError: If the response body is not a valid ``EntityResponse``.
            Whatever ``res.raise_for_status()`` raises for a non-OK response.
        """
        res = self._session.make_request("GET", f"{self._route_stem}/{id}")
        if not res.ok:
            res.raise_for_status()
        dataset = entity_pb2.EntityResponse()
        try:
            dataset.ParseFromString(res.content)
        except DecodeError as e:
            raise MantleProtoError(res.content, entity_pb2.EntityResponse) from e
        return MantleDataset(dataset_input=dataset.entity, session=self._session, storage_client=self._storage_client)

    @deprecated("2.0.0", "use get() instead")
    def get_entity(self, id: str) -> _IDataset:
        """Get a dataset by ID (deprecated alias of ``get``).

        Args:
            id (str): Dataset ID.

        Returns:
            _IDataset: The dataset.
        """
        return self.get(id)

    @deprecated("2.0.0", "use list() instead")
    def get_entities(self) -> Iterable[_IDataset]:
        """Get all datasets (deprecated alias of ``list``).

        Returns:
            Iterable[_IDataset]: All datasets.
        """
        # Call list() directly rather than the also-deprecated get_datasets(),
        # so only one deprecation warning is emitted.
        return self.list()

    @deprecated("2.0.0", "use list() instead")
    def get_datasets(self) -> Iterable[_IDataset]:
        """Get all datasets (deprecated alias of ``list``).

        Returns:
            Iterable[_IDataset]: All datasets.
        """
        return self.list()

    def build_query(self) -> QueryBuilder:
        """Return a ``QueryBuilder`` bound to this client."""
        return QueryBuilder(self)

    def list(self, query_params: Optional[Dict[str, str]] = None) -> ListResponse[_IDataset]:
        """Get every dataset matching ``query_params``, following all result pages.

        Args:
            query_params (Dict[str, str], optional): Query parameters sent with each
                page request. The caller's dict is never modified. Defaults to None.

        Returns:
            ListResponse[_IDataset]: The datasets from every page, in server order.

        Raises:
            MantleProtoError: If a page is not a valid ``EntitiesResponse``.
            Whatever ``res.raise_for_status()`` raises for a non-OK response.
        """
        # Work on a private copy. The old ``= {}`` default was shared across calls
        # and was mutated with ``page_token``, so the token leaked into later calls
        # and into caller-supplied dicts.
        params: Dict[str, str] = dict(query_params) if query_params else {}
        datasets = ListResponse[_IDataset]()
        while True:
            res = self._session.make_request("GET", self._route_stem, params=params)
            if not res.ok:
                res.raise_for_status()
            dataset_list = entity_pb2.EntitiesResponse()
            try:
                dataset_list.ParseFromString(res.content)
            except DecodeError as e:
                raise MantleProtoError(res.content, entity_pb2.EntitiesResponse) from e
            for dataset in dataset_list.entities:
                datasets.append(MantleDataset(
                    dataset_input=dataset, session=self._session, storage_client=self._storage_client))
            # An empty next_page_token marks the final page.
            if not dataset_list.next_page_token:
                break
            params["page_token"] = dataset_list.next_page_token
        return datasets

    def _create_local_dataset(self, name: Optional[str] = None, dataset_type: Optional[str] = None, properties: Optional[Dict[str, Any]] = None, origin: Optional[entity_pb2.Origin] = None) -> _IDataset:
        """Build a dataset in memory only; no request is made.

        Only the arguments that were supplied are included in the input mapping.
        ``props`` is always included, defaulting to an empty dict.
        """
        dataset_params_json: Dict[str, Any] = {}
        if not properties:
            properties = {}
        if dataset_type:
            dataset_params_json.update(
                {"data_type": {"unique_id": dataset_type}})
        if name:
            dataset_params_json.update({"name": name})
        dataset_params_json.update({"props": properties})
        if origin:
            dataset_params_json.update({"origin": origin})
        return MantleDataset(dataset_input=dataset_params_json, session=self._session, storage_client=self._storage_client, local=True)

    def _create_cloud_dataset(self, dataset_type: str, properties: Optional[Dict[str, Any]] = None, origin: Optional[entity_pb2.Origin] = None, name: Optional[str] = None) -> _IDataset:
        """Create a dataset on the server and upload any file-valued properties.

        Raises:
            MantleMissingParameterError: If the server response lacks the S3 file
                (or its key) for a property that was sent as a file upload.
            Whatever ``res.raise_for_status()`` raises for a non-OK response.
        """
        if not properties:
            properties = {}
        property_builder = DatasetPropertiesBuilder()
        dataset_props = property_builder.build_create_dataset_props(
            properties)
        dataset_req = entity_pb2.CreatEntityRequest(
            name=name,
            data_type_id=dataset_type,
            props=dataset_props,
            origin=origin
        )
        res = self._session.make_request(
            "POST", self._route_stem, data=dataset_req)
        if not res.ok:
            res.raise_for_status()
        dataset_res = unmarshall_dataset(res.content)
        for key, val in dataset_props.items():
            if val.WhichOneof('value') != 'file_upload':
                continue
            # A protobuf sub-message attribute is never None (it returns a default
            # instance), so presence has to be tested explicitly. The old
            # ``is None`` check could never fire.
            if key not in dataset_res.props or not dataset_res.props[key].HasField("s3_file"):
                raise MantleMissingParameterError(
                    f"Property {key} is missing an S3 file.")
            upload_prefix = dataset_res.props[key].s3_file.key
            if not upload_prefix:
                raise MantleMissingParameterError(
                    f"Property {key} is missing an S3 file key.")
            self._storage_client.upload_file(
                val.file_upload.filename, upload_prefix)
        return MantleDataset(dataset_input=dataset_res, session=self._session, storage_client=self._storage_client)

    def create(self, name: Optional[str] = None, dataset_type: Optional[str] = None, properties: Optional[Dict[str, Any]] = None, local: bool = True, origin: Optional[entity_pb2.Origin] = None, entity_type: Optional[str] = None) -> _IDataset:
        """Create a dataset, either locally (in memory) or on the server.

        Args:
            name (str, optional): Dataset name.
            dataset_type (str, optional): Dataset type; required when ``local`` is False.
            properties (Dict[str, Any], optional): Dataset properties.
            local (bool, optional): Create in memory only. Defaults to True.
            origin (entity_pb2.Origin, optional): Dataset origin.
            entity_type (str, optional): Deprecated alias for ``dataset_type``. When
                given, it overrides ``dataset_type``.

        Returns:
            _IDataset: The created dataset.

        Raises:
            MantleInvalidParameterError: If ``local`` is False and no dataset type is given.
        """
        if entity_type:
            warnings.warn("entity_type parameter is deprecated and will be removed in version 2.0.0. Use dataset_type instead.",
                          category=DeprecationWarning, stacklevel=2)
            dataset_type = entity_type
        if local:
            return self._create_local_dataset(name, dataset_type, properties, origin)
        if not dataset_type:
            raise MantleInvalidParameterError("dataset_type is required")
        return self._create_cloud_dataset(dataset_type, properties, origin, name=name)

    @deprecated("2.0.0", "use create() instead.")
    def create_cloud_entity(self, entity_type: str, properties: Optional[Dict[str, Any]] = None) -> _IDataset:
        """Deprecated: create a cloud dataset via ``create(local=False)``."""
        return self.create(dataset_type=entity_type, properties=properties, local=False)

    @deprecated("2.0.0", "use create() instead.")
    def create_local_entity(self, entity_type: Optional[str] = None, properties: Optional[Dict[str, Any]] = None) -> _IDataset:
        """Deprecated: create a local dataset via ``create(local=True)``."""
        return self.create(dataset_type=entity_type, properties=properties, local=True)

    @deprecated("2.0.0", "use create() instead.")
    def create_empty_entity(self):
        """Deprecated: create an empty local dataset."""
        # Call create() directly rather than the also-deprecated
        # create_local_entity(), so only one deprecation warning is emitted.
        return self.create(local=True)
class EntityClient(DatasetClient):
    """Deprecated alias of ``DatasetClient``; use ``DatasetClient`` instead."""

    @deprecated("2.0.0", "Use DatasetClient instead")
    def __init__(self, session: _ISession, storage_client: _IStorageClient, analysis_id: str = "", pipeline_run_id: str = "") -> None:
        # Pure delegation: apart from the deprecation decorator, all setup is DatasetClient's.
        super().__init__(
            session=session,
            storage_client=storage_client,
            analysis_id=analysis_id,
            pipeline_run_id=pipeline_run_id,
        )