# Source code for plate_model_manager.plate_model_manager

import json
import logging
import os
import re
from typing import Dict, Union

import requests

from .utils.enums import ReferenceFrame

from .exceptions import InvalidConfigFile, ServerUnavailable
from .plate_model import PlateModel

# Module logger; all records emitted by this module go to the logger named "pmm".
logger = logging.getLogger(name="pmm")


class PlateModelManager:
    """Manage a set of publicly available plate reconstruction models.

    The model files are hosted on EarthByte servers. You need Internet
    connection to use this class and download the files.
    """

    # Load a models.json file and manage plate models.
    # See an example models.json file at PlateModelManager.get_default_repo_url().

    # Matches one ``@<<name>>@`` placeholder. The group is non-greedy so a string
    # containing several placeholders yields one match per placeholder (a greedy
    # ``.*`` would swallow everything between the first ``@<<`` and last ``>>@``).
    _VAR_PATTERN = re.compile(r"@<<(.*?)>>@")

    def __init__(self, model_manifest: str = "", timeout=(None, None)):
        """Constructor. Create a :class:`PlateModelManager` instance.

        You need Internet connection to create an instance of this class. If you
        don't have Internet connection, use :class:`PlateModel` class directly in
        ``readonly`` mode. Visit `this page <examples.html#use-without-internet>`__
        to see an example.

        :param model_manifest: The URL to a ``models.json`` metadata file, or a
            local file path to one. Normally you don't need to provide this
            parameter unless you would like to setup your own plate model server.
            When empty, :meth:`get_default_repo_url` picks a default server.
        :param timeout: Passed as ``timeout`` to :func:`requests.get` when the
            manifest is fetched over http(s). The default ``(None, None)``
            sets no connect/read timeout.
        :raises InvalidConfigFile: If ``model_manifest`` is not a string, is
            neither an existing local file nor an http(s) URL, cannot be
            decoded as JSON, or does not contain a JSON object.
        :raises ServerUnavailable: If the http(s) manifest cannot be reached.
        """
        if not model_manifest:
            self.model_manifest = PlateModelManager.get_default_repo_url()
        else:
            self.model_manifest = model_manifest
        self._models = None
        self.timeout = timeout

        if not isinstance(self.model_manifest, str):
            raise InvalidConfigFile(
                f"The model_manifest '{type(self.model_manifest)}' must be a string. It is either a local file path or a http(s) URL."
            )

        # A local file takes precedence over URL interpretation.
        if os.path.isfile(self.model_manifest):
            self._models = self._load_local_manifest(self.model_manifest)
        elif self.model_manifest.startswith(("http://", "https://")):
            self._models = self._fetch_remote_manifest(self.model_manifest, timeout)
        else:
            raise InvalidConfigFile(
                f"The model_manifest '{self.model_manifest}' must be either a local file path or a http(s) URL."
            )

        # Everything below (and every lookup in this class) indexes the
        # manifest by model name, so it must be a JSON object.
        if not isinstance(self._models, dict):
            raise InvalidConfigFile(
                f"The model manifest '{self.model_manifest}' must contain a JSON object."
            )

        # "vars" holds placeholder values; expand them in place, then drop the
        # entry so it is not reported as a model name.
        if "vars" in self.models:
            self._replace_vars_with_values(self.models["vars"], self.models)
            del self.models["vars"]

    @staticmethod
    def _load_local_manifest(path: str):
        """Parse a local ``models.json`` file and return the decoded JSON.

        :param path: Path to an existing local file.
        :raises InvalidConfigFile: If the file is not valid UTF-8 JSON.
        """
        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
            with open(path, encoding="utf-8") as f:
                return json.load(f)
        except ValueError as err:  # JSONDecodeError and UnicodeDecodeError
            raise InvalidConfigFile(
                f"Unable to get valid JSON data from '{path}'."
            ) from err

    @staticmethod
    def _fetch_remote_manifest(url: str, timeout):
        """Download a ``models.json`` file over http(s) and return the decoded JSON.

        :param url: An ``http://`` or ``https://`` URL.
        :param timeout: Forwarded to :func:`requests.get`.
        :raises ServerUnavailable: On connection errors or timeouts.
        :raises InvalidConfigFile: On a non-200 response or a non-JSON body.
        """
        try:
            r = requests.get(url, timeout=timeout)
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.ConnectTimeout,
            requests.exceptions.ReadTimeout,
        ) as err:
            raise ServerUnavailable(
                f"Unable to fetch {url}. No network connection, server unavailable or invalid URL!"
            ) from err
        if r.status_code != 200:
            raise InvalidConfigFile(
                f"Unable to get valid JSON data from '{url}'. Http request return code: {r.status_code}"
            )
        try:
            return r.json()
        except requests.exceptions.JSONDecodeError as err:
            raise InvalidConfigFile(
                f"Unable to get valid JSON data from '{url}'."
            ) from err

    @property
    def models(self) -> Dict:
        """The metadata for all the models.

        :raises Exception: If no manifest data has been loaded.
        """
        if self._models is not None:
            return self._models
        else:
            raise Exception(
                f"No model found. Check the model manifest {self.model_manifest} for errors."
            )

    @models.setter
    def models(self, var) -> None:
        self._models = var

    def _replace_vars_with_values(self, var_dict, json_obj):
        """Replace the ``@<<name>>@`` placeholders in `json_obj` with real values.

        The values are looked up in `var_dict`. `json_obj` is modified in place;
        nested dicts and lists are processed recursively. Dict entries whose key
        is ``"vars"`` are skipped. Placeholders whose name is not a key of
        `var_dict` are left untouched. Non-string values are inserted via ``str()``.

        :param var_dict: Mapping from placeholder name to replacement value.
        :param json_obj: The dict (or list) to process.
        """
        if isinstance(json_obj, dict):
            entries = [(k, v) for k, v in json_obj.items() if k != "vars"]
        elif isinstance(json_obj, list):
            entries = list(enumerate(json_obj))
        else:
            return

        def substitute(match):
            name = match.group(1)
            return str(var_dict[name]) if name in var_dict else match.group(0)

        for key, value in entries:
            if isinstance(value, (dict, list)):
                self._replace_vars_with_values(var_dict, value)
            elif isinstance(value, str):
                json_obj[key] = self._VAR_PATTERN.sub(substitute, value)

    def _resolve_model_config(
        self,
        model_name: str,
        data_dir: str,
        visited: Union[set, None] = None,
        max_depth: int = 10,
    ) -> Union[dict, None]:
        """Resolve model configuration, following alias chains safely.

        A manifest entry that is a string is an alias (optionally prefixed with
        ``@``) naming another entry; a dict entry is a model configuration.

        :param model_name: The model name (case-insensitive).
        :param data_dir: The folder to save model files (currently unused here).
        :param visited: Names already followed before this call; used to detect
            circular aliases.
        :param max_depth: Maximum number of names that may be followed.
        :returns: The resolved model configuration dict, or None if a name in
            the chain is not in the manifest or the final entry is not a dict.
        :raises InvalidConfigFile: If a circular alias is detected or
            ``max_depth`` is exceeded.
        """
        # Ordered record of followed names: gives cycle detection and a
        # readable "a -> b -> c" chain in error messages (a set has no order).
        chain = list(visited) if visited else []
        name = model_name.lower()
        while True:
            if len(chain) >= max_depth:
                raise InvalidConfigFile(
                    f"Maximum alias resolution depth ({max_depth}) exceeded. "
                    f"Possible circular alias in model manifest {self.model_manifest}. "
                    f"Resolution chain: {' -> '.join(chain)} -> {name}"
                )
            if name in chain:
                raise InvalidConfigFile(
                    f"Circular alias detected in model manifest {self.model_manifest}. "
                    f"Resolution chain: {' -> '.join(chain)} -> {name}"
                )
            if name not in self.models:
                return None
            entry = self.models[name]
            chain.append(name)
            if not isinstance(entry, str):
                # Entry is the configuration itself (only dicts are valid).
                return entry if isinstance(entry, dict) else None
            # String entry: an alias, with an optional '@' marker prefix.
            target = entry[1:] if entry.startswith("@") else entry
            name = target.lower()

    def get_model(
        self,
        model_name: str = "default",
        data_dir: str = ".",
        reference_frame: Union[ReferenceFrame, None] = None,
    ) -> Union[PlateModel, None]:
        """Retrieve a :class:`PlateModel` object for a given plate model name.

        This method resolves model aliases and creates a PlateModel instance
        configured with the model's metadata. Alias resolution follows chains
        and detects circular references to prevent infinite loops. Call
        :meth:`get_available_model_names()` to see a list of available model
        names and valid aliases.

        :param model_name: The name of the plate model to retrieve.
            Case-insensitive. Can be a direct model name, an alias, or a variant
            with reference frame suffix (e.g., "model_pmag_ref"). Defaults to
            "default".
        :param data_dir: The folder path to save downloaded plate model files.
            Defaults to the current directory ("."); This path can be changed
            later with :meth:`PlateModel.set_data_dir()`.
        :param reference_frame: Optional reference frame for the plate model. If
            set to :attr:`ReferenceFrame.PmagReferenceFrame` and a "_pmag_ref"
            variant exists, that variant will be loaded automatically.
        :returns: A :class:`PlateModel` object if the model is found and
            successfully created, ``None`` if the model name is not found in the
            manifest (or no PMAG reference frame version is available when one
            is requested).
        :raises InvalidConfigFile: If a circular alias chain is detected or if
            the maximum alias resolution depth is exceeded, indicating an error
            in the model manifest.
        :example:

            >>> pmm = PlateModelManager()
            >>> model = pmm.get_model("muller2016", data_dir="./models")
            >>> if model:
            ...     model.download_all_layers()
        """
        model_name_lower = model_name.lower()
        pmag_requested = reference_frame == ReferenceFrame.PmagReferenceFrame

        # Prefer a dedicated "<name>_pmag_ref" manifest entry when available.
        if pmag_requested and f"{model_name_lower}_pmag_ref" in self.models:
            model_name_lower += "_pmag_ref"

        # InvalidConfigFile (circular/too-deep aliases) propagates to the caller.
        model_cfg = self._resolve_model_config(model_name_lower, data_dir)
        if model_cfg is None:
            logger.error(f"Model '{model_name}' is not available.")
            return None

        # Without a dedicated variant, the model itself must declare a PMAG
        # anchor plate ID to be usable in the PMAG reference frame.
        if pmag_requested and not model_name_lower.endswith("_pmag_ref"):
            if (
                model_cfg.get("Attributes", {}).get("PmagReferenceFrameAnchorPID")
                is None
            ):
                logger.error(
                    f"Model '{model_name}' does not have a PMAG reference frame version available."
                )
                return None

        return PlateModel(
            model_name_lower,
            model_cfg=model_cfg,
            data_dir=data_dir,
            reference_frame=reference_frame,
        )

    def get_available_model_names(self):
        """Return the names of available models (aliases included) as a list."""
        return list(self.models.keys())

    @staticmethod
    def get_local_available_model_names(local_dir: str):
        """Return a list of model names in a local folder.

        A sub-folder counts as a model when it contains a ``.metadata.json``
        file; the folder name is the model name.

        :param local_dir: The local folder containing models.
        :type local_dir: str
        """
        return [
            name
            for name in os.listdir(local_dir)
            # isfile() on a path below `name` also implies `name` is a directory.
            if os.path.isfile(os.path.join(local_dir, name, ".metadata.json"))
        ]

    @staticmethod
    def get_default_repo_url():
        """Return the URL to the configuration data of models.

        The mirrors below are probed in order with an HTTP ``HEAD`` request
        (5 s connect and read timeout); the first one answering with status
        200 is returned.

        :raises ServerUnavailable: If no mirror answers with status 200.
        """
        default_repo_url_list = [
            "https://repo.gplates.org/webdav/pmm/config/models_v2.json",
            "https://www.earthbyte.org/webdav/pmm/config/models_v2_eb.json",
            "https://portal.gplates.org/static/pmm/config/models_v2_gp.json",
        ]
        for url in default_repo_url_list:
            try:
                response = requests.head(url, timeout=(5, 5))
            except requests.exceptions.RequestException:
                # Network-level failure (DNS, refused connection, timeout, bad
                # URL): fall through to the next mirror. Catching only requests'
                # errors lets KeyboardInterrupt/SystemExit propagate.
                logger.warning(f"Unable to fetch {url}.")
                continue
            if response.status_code == 200:
                return url
            logger.warning(
                f"Unable to fetch {url}. status_code={response.status_code}"
            )
        raise ServerUnavailable(
            """Cannot connect to the servers. Either the servers are currently unavailable, or there is a problem with your internet connection."""
        )

    def download_all_models(self, data_dir: str = "./") -> None:
        """Download all available models into the ``data_dir``.

        Every name from :meth:`get_available_model_names` (aliases included) is
        passed to :meth:`get_model`; names for which it returns ``None`` are
        skipped.

        :param data_dir: The folder to save the model files.
        :type data_dir: str
        """
        for name in self.get_available_model_names():
            print(f"download {name}")
            model = self.get_model(name)
            if model is not None:
                model.set_data_dir(data_dir)
                model.download_all_layers()