Source code for openghg.dataobjects._searchresults

import logging
from typing import Any, TypeVar
from collections.abc import Iterable

from openghg.dataobjects import ObsData
from pandas import DataFrame

__all__ = ["SearchResults"]

T = TypeVar("T", bound="SearchResults")

logger = logging.getLogger("openghg.dataobjects")
logger.setLevel(logging.DEBUG)  # Have to set level for logger as well as handler


[docs] class SearchResults: """This class is used to return data from the search function. It has member functions to retrieve data from the object store. Args: keys: Dictionary of keys keyed by Datasource UUID metadata: Dictionary of metadata keyed by Datasource UUID start_result: ? """ # TODO - WIP move to tinydb metadata lookup to simplify code
[docs] def __init__( self, metadata: dict | None = None, start_result: str | None = None, start_date: str | None = None, end_date: str | None = None, ): # db = tinydb.TinyDB(tinydb.storages.MemoryStorage) if metadata is not None: self.metadata = metadata # db.insert_multiple([m for m in metadata.values()]) self.results = ( DataFrame.from_dict(data=metadata, orient="index").reset_index().drop(columns="index") ) if start_result is not None: for uuid_key, uuid_metadata in metadata.items(): if start_result in uuid_metadata: other_keys = list(uuid_metadata.keys()) other_keys.remove(start_result) reorder = [start_result] + other_keys metadata[uuid_key] = {key: uuid_metadata[key] for key in reorder} else: self.results = {} # type: ignore self.metadata = {} self._start_date = start_date self._end_date = end_date
[docs] def __str__(self) -> str: SearchResults.df_to_table_console_output(df=DataFrame.from_dict(data=self.metadata)) return f"Found {len(self.results)} results.\nView the results DataFrame using the results property."
[docs] def __repr__(self) -> str: return self.__str__()
def __bool__(self) -> bool: return bool(self.metadata) def __len__(self) -> int: return len(self.metadata) # def __iter__(self) -> Iterator: # yield from self.results.iterrows()
[docs] def retrieve( self, dataframe: DataFrame | None = None, version: str = "latest", sort: bool = True, **kwargs: Any, ) -> ObsData | list[ObsData]: """Retrieve data from object store using a filtered pandas DataFrame Args: dataframe: pandas DataFrame version: Version of data requested from Datasource. Default = "latest". sort: Sort data by time in retrieved Dataset **kwargs: Metadata values to search for Returns: ObsData / List[ObsData]: ObsData object(s) """ if dataframe is not None: uuids = dataframe["uuid"].to_list() return self._retrieve_by_uuid(uuids=uuids, version=version, sort=sort) else: return self._retrieve_by_term(version=version, sort=sort, **kwargs)
[docs] def retrieve_all( self, version: str = "latest", sort: bool = True, ) -> ObsData | list[ObsData]: """Retrieves all data found during the search Args: version: Version of data requested from Datasource. Default = "latest". sort: Sort by time. Note that this may be very memory hungry for large Datasets. Returns: ObsData / List[ObsData]: ObsData object(s) """ return self._retrieve_by_uuid(uuids=self.metadata.keys(), version=version, sort=sort)
[docs] def uuids(self) -> list: """Return the UUIDs of the found data Returns: list: List of UUIDs """ return list(self.metadata.keys())
def _retrieve_by_term(self, version: str, sort: bool = True, **kwargs: Any) -> ObsData | list[ObsData]: """Retrieve data from the object store by search term. This function scans the metadata of the retrieved results, retrieves the UUID associated with that data, pulls it from the object store, recombines it into an xarray Dataset and returns ObsData object(s). Args: version: Version of data requested from Datasource. Default = "latest". sort: Sort by time. Note that this may be very memory hungry for large Datasets. **kwargs: Metadata values to search for """ uuids = set() # Make sure we don't have any Nones clean_kwargs = {k: v for k, v in kwargs.items() if v is not None} for uid, metadata in self.metadata.items(): n_required = len(clean_kwargs) n_matched = 0 for key, value in clean_kwargs.items(): try: # Here we want to check if it's a list and if so iterate over it if isinstance(value, (list, tuple)): for val in value: val = str(val).lower() if metadata[key.lower()] == val: n_matched += 1 break else: value = str(value).lower() if metadata[key.lower()] == value: n_matched += 1 except KeyError: pass if n_matched == n_required: uuids.add(uid) # Now we can retrieve the data using the UUIDs return self._retrieve_by_uuid(uuids=list(uuids), version=version, sort=sort) def _retrieve_by_uuid(self, uuids: Iterable, version: str, sort: bool = True) -> ObsData | list[ObsData]: """Internal retrieval function that uses the passed in UUIDs to retrieve the keys from the key_data dictionary, pull the data from the object store, create ObsData object(s) and return the result. Args: uuids: UUIDs of Datasources in the object store version: Version of data requested from Datasource. Default = "latest". sort: Sort by time. Note that this may be very memory hungry for large Datasets. Returns: ObsData / List[ObsData]: ObsData object(s) """ results = [] for uuid in uuids: metadata = self.metadata[uuid] if version == "latest": version = metadata["latest_version"] else: if version not in metadata["versions"]: raise ValueError(f"Invalid version {version} for UUID {uuid}") results.append( ObsData( uuid=uuid, version=version, metadata=metadata, start_date=self._start_date, end_date=self._end_date, sort=sort, ) ) if len(results) == 1: return results[0] else: return results
[docs] @staticmethod def df_to_table_console_output(df: DataFrame) -> None: """ Process the DataFrame and display it as a formatted table in the console. Args: df (DataFrame): The DataFrame to be processed and displayed. Returns: None """ try: from rich import print from rich.table import Table, box except ModuleNotFoundError: logger.warning("Unable to use rich package to display search results. Please install rich") return None # Split columns into sets column_sets = [df.columns[i : i + 4] for i in range(0, len(df.columns), 4)] # Iterate over the column sets for columns in column_sets: # Create a table instance table = Table(show_header=False, header_style="bold", box=box.HORIZONTALS) # Add table headers for i, column in enumerate(columns, start=1): if i == 1: table.add_column(column, style="bold") else: table.add_column(column) # Add table data for index, row in df.iterrows(): row_data = [str(index)] + [str(row[column]) for column in columns] table.add_row(*row_data) # Print the table print(table) print()