import functools
import inspect
import os
import warnings
from pathlib import Path
from typing import Any, Dict

from mantlebio.core.analysis.helpers import (
    unmarshall_analysis_proto,
    validate_analysis_value,
)
from mantlebio.core.dataset.client import _IDatasetClient
from mantlebio.core.dataset.mantle_dataset import _IDataset, MantleDataset
from mantlebio.core.session.mantle_session import _ISession
from mantlebio.core.storage.client import _IStorageClient
from mantlebio.exceptions import MantleAttributeError, MantleInvalidParameterError
from mantlebio.helpers.decorators import deprecated
from proto import analysis_pb2, data_type_pb2, entity_pb2


class MantleAnalysis:
    """Wrapper for Analysis proto object with functions that extend the functionality of the proto object.

    Attributes that are not defined on the wrapper itself are looked up on
    the wrapped ``analysis_pb2.Analysis`` proto (see ``__getattr__``).
    """
    def __init__(self, analysis: analysis_pb2.Analysis, session: _ISession, storage_client: _IStorageClient, dataset_client: _IDatasetClient) -> None:
        """
        Args:
            analysis (analysis_pb2.Analysis): The Analysis proto object
            session (_ISession): MantleSession object for making calls to the mantle api
            storage_client (_IStorageClient): Storage client for uploading files
            dataset_client (_IDatasetClient): Dataset client for interacting with datasets
        """
        self._session = session
        self._storage_client = storage_client
        self._dataset_client = dataset_client
        # Attach the proto before calling any other method: ``__getattr__``
        # falls back to ``self._analysis_instance``, so an attribute miss
        # raised before this assignment could not be resolved cleanly.
        self._analysis_instance = analysis  # the proto object
        self._name = ""
        self.set_inputs(analysis.inputs)
        self.set_outputs(analysis.outputs)
[docs]
def _wrap_method(self, method):
def wrapper(*args, **kwargs):
return method(*args, **kwargs)
return wrapper
def __getattr__(self, name):
# First, check if the object itself has the property
# TODO: this should be removed when we deprecate the proto property accessors
try:
return super().__getattribute__(name)
except AttributeError:
pass
# Dynamically route attribute access to the protobuf object
if hasattr(self._analysis_instance, name):
attr = getattr(self._analysis_instance, name)
if inspect.ismethod(attr):
return self._wrap_method(attr)
return attr
raise MantleAttributeError(f"'{type(self._analysis_instance).__name__}' object has no attribute '{name}'")
[docs]
def set_name(self, name: str):
"""
Set the name of the MantleAnalysis object.
Args:
name (str): The name of the MantleAnalysis object.
"""
self._name = name
[docs]
def set_outputs(self, outputs: analysis_pb2.AnalysisOutput):
"""
Set the outputs of the MantleAnalysis object.
Args:
outputs (analysis_pb2.AnalysisOutput): The outputs of the MantleAnalysis object.
"""
self._outputs = outputs.data
[docs]
def add_file_output(self, output_key: str, local_path: str):
"""
Add a file output to the MantleAnalysis object.
Args:
output_key (str): The key of the output.
local_path (str): The local path of the file.
"""
new_s3_file_pb2 = data_type_pb2.FileUpload(
filename=local_path
)
self.add_output(output_key, new_s3_file_pb2)
[docs]
def get_output(self, key: str):
"""
Get the value of an output.
Args:
key (str): The key of the output.
Returns:
Any: The value of the output.
"""
output = self._outputs[key]
value_type = output.WhichOneof('value')
if value_type == "entity":
return MantleDataset(output.entity, self._session, self._dataset_client)
return output.__getattribute__(value_type)
[docs]
def add_output(self, output_key: str, val: Any, force: bool = False):
"""
Add an output to the MantleAnalysis object.
Args:
output_key (str): The key of the output.
val (Any): The value of the output.
"""
if not force and self._outputs.get(output_key):
Warning(f"Output '{output_key}' already exists. Use force=True to overwrite.")
return
if isinstance(val, _IDataset):
val = val.to_proto()
analysis_value_args = validate_analysis_value(val)
self._outputs.get_or_create(output_key).CopyFrom(analysis_pb2.AnalysisValue(
**analysis_value_args))
self.push_output(output_key)
[docs]
@deprecated("2.0.0", "Use get_output instead.")
def pull_entity(self, key: str, entity_id: str) -> _IDataset:
"""
Pull a dataset from the analysis input.
Args:
key (str): The key of the input.
entity_id (str): The ID of the dataset.
Returns:
_IDataset: The pulled dataset.
"""
e = self._dataset_client.get(entity_id)
self.add_input(key, e)
return e
[docs]
    def push_output(self, key: str) -> None:
        """
        Push a value to the analysis output.

        Sends the locally stored output ``key`` to the Mantle API with
        ``POST /analysis/{unique_id}/output/{key}`` and replaces the wrapped
        Analysis proto with the one parsed from the response. For a
        ``file_upload`` output, the local file is then uploaded through the
        storage client to the S3 key found in the response. The output is
        rewritten as an ``s3_file`` value both locally and in the refreshed
        proto.

        Args:
            key (str): The key of the output.

        Raises:
            KeyError: If no output named ``key`` has been added.
            MantleInvalidParameterError: If a ``file_upload`` output has no
                filename.
            Exception: Whatever ``analysis_resp.raise_for_status()`` raises on
                a non-OK response (presumably ``requests.HTTPError`` — confirm
                against the session implementation).
        """
        if not self._outputs.get(key):
            raise KeyError(f"Output '{key}' not found.")
        # Name of the member currently set in the output's ``value`` oneof.
        field_label = self._outputs[key].WhichOneof('value')
        if field_label == "file_upload":
            # Validate the filename before any request is sent.
            file_upload_pb2_obj: data_type_pb2.FileUpload = self.get_output(key)
            if not file_upload_pb2_obj.filename:
                raise MantleInvalidParameterError("FileUpload object must have a filename attribute")
        if field_label == "entity":
            # get_output wraps entities in a MantleDataset; store its proto form.
            entity_pb2_obj: entity_pb2.Entity = self.get_output(key).to_proto()
            self._outputs[key].CopyFrom(analysis_pb2.AnalysisValue(entity=entity_pb2_obj))
        # update the analysis record with a new output
        analysis_resp = self._session.make_request(
            "POST", f"/analysis/{self.unique_id}/output/{key}",
            data=self._outputs[key]
        )
        if not analysis_resp.ok:
            analysis_resp.raise_for_status()
        analysis_pb2_obj = unmarshall_analysis_proto(
            proto_content=analysis_resp.content)
        # NOTE(review): only the wrapped proto is replaced here; ``self._outputs``
        # is not re-bound to the new proto's outputs map — confirm intended.
        self._analysis_instance = analysis_pb2_obj
        if field_label == "file_upload":
            # The server response supplies the S3 key to upload the file to.
            s3_key = analysis_pb2_obj.outputs.data[key].s3_file.key
            new_s3_file_pb2 = self._storage_client.upload_file(
                path=file_upload_pb2_obj.filename, upload_key=s3_key)
            # Record the uploaded file in the local outputs map and in the
            # refreshed proto.
            self._outputs[key].CopyFrom(analysis_pb2.AnalysisValue(s3_file=new_s3_file_pb2))
            self._analysis_instance.outputs.data[key].s3_file.CopyFrom(new_s3_file_pb2
            )