import json
import logging
import os
import re
from typing import Dict, Union
import requests
from .utils.enums import ReferenceFrame
from .exceptions import InvalidConfigFile, ServerUnavailable
from .plate_model import PlateModel
logger = logging.getLogger("pmm")
class PlateModelManager:
    """Manage discovery and loading of plate reconstruction model metadata.

    Model manifests can be loaded from a local file or an HTTP(S) endpoint.
    Retrieved model configurations are used to construct :class:`PlateModel`
    instances.

    The manager loads a ``models.json`` file; see an example at the URL
    returned by :meth:`get_default_repo_url`.
    """

    # Template markers look like ``@<<name>>@``. The capture is non-greedy so
    # that several markers inside one string are matched individually instead
    # of being swallowed as a single (unknown) variable name.
    _VAR_MARKER = re.compile(r"@<<(.*?)>>@")

    def __init__(self, model_manifest: str = "", timeout=(None, None)):
        """Create a :class:`PlateModelManager` instance.

        If ``model_manifest`` is omitted, the manager probes known PMM manifest
        endpoints and uses the first reachable URL.

        :param model_manifest: Local path or HTTP(S) URL for a ``models.json``
            manifest. Use this when hosting a custom model repository.
        :param timeout: Timeout tuple passed to HTTP requests.
        :raises InvalidConfigFile: If the manifest path/URL is invalid, does
            not contain valid JSON, or its top level is not a JSON object.
        :raises ServerUnavailable: If the manifest URL cannot be reached.
        """
        if not model_manifest:
            self.model_manifest = PlateModelManager.get_default_repo_url()
        else:
            self.model_manifest = model_manifest
        self._models = None
        self.timeout = timeout

        if not isinstance(self.model_manifest, str):
            raise InvalidConfigFile(
                f"The model_manifest '{type(self.model_manifest)}' must be a string. It is either a local file path or a http(s) URL."
            )

        # A local file takes precedence; otherwise the value must be a URL.
        if os.path.isfile(self.model_manifest):
            self._models = self._load_local_manifest(self.model_manifest)
        elif self.model_manifest.startswith(("http://", "https://")):
            self._models = self._fetch_remote_manifest(self.model_manifest, timeout)
        else:
            raise InvalidConfigFile(
                f"The model_manifest '{self.model_manifest}' must be either a local file path or a http(s) URL."
            )

        # Everything below indexes the manifest by key, so reject valid JSON
        # that is not an object (e.g. a list or ``null``) up front.
        if not isinstance(self._models, dict):
            raise InvalidConfigFile(
                f"Unable to get valid JSON data from '{self.model_manifest}'. The top-level JSON value must be an object."
            )

        if "vars" in self.models:
            self._replace_vars_with_values(self.models["vars"], self.models)
            del self.models["vars"]

    @staticmethod
    def _load_local_manifest(path):
        """Read and parse the JSON manifest stored in the local file ``path``.

        :raises InvalidConfigFile: If the file content is not valid JSON.
        """
        try:
            with open(path, encoding="utf-8") as f:
                return json.load(f)
        except ValueError as err:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            raise InvalidConfigFile(
                f"Unable to get valid JSON data from '{path}'."
            ) from err

    @staticmethod
    def _fetch_remote_manifest(url, timeout):
        """Download and parse the JSON manifest served at ``url``.

        :raises ServerUnavailable: If the server cannot be reached in time.
        :raises InvalidConfigFile: If the response status is not 200 or the
            response body is not valid JSON.
        """
        try:
            r = requests.get(url, timeout=timeout)
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.ConnectTimeout,
            requests.exceptions.ReadTimeout,
        ) as err:
            raise ServerUnavailable(
                f"Unable to fetch {url}. No network connection, server unavailable or invalid URL!"
            ) from err

        if r.status_code != 200:
            raise InvalidConfigFile(
                f"Unable to get valid JSON data from '{url}'. Http request return code: {r.status_code}"
            )

        try:
            return r.json()
        except ValueError as err:
            # Every JSON decode error raised by Response.json() derives from
            # ValueError, including on requests releases that predate
            # requests.exceptions.JSONDecodeError.
            raise InvalidConfigFile(
                f"Unable to get valid JSON data from '{url}'."
            ) from err

    @property
    def models(self) -> Dict:
        """Return metadata for all configured models.

        :returns: Mapping from model names to model entries.
        :rtype: Dict
        :raises Exception: If model metadata is unavailable.
        """
        if self._models is not None:
            return self._models
        raise Exception(
            f"No model found. Check the model manifest {self.model_manifest} for errors."
        )

    @models.setter
    def models(self, var) -> None:
        self._models = var

    def _replace_vars_with_values(self, var_dict, json_obj):
        """Expand template variables in-place within a manifest dictionary.

        Variables use the marker format ``@<<name>>@`` and are resolved from
        ``var_dict``. Markers whose name is not in ``var_dict`` are left as
        they are. Only nested dictionaries are traversed; any entry stored
        under a ``"vars"`` key is skipped.

        :param var_dict: Variable name/value mapping. Values are converted
            with :func:`str` before substitution.
        :param json_obj: JSON-like dictionary to mutate in place.
        """

        def _expand(match):
            name = match.group(1)
            if name in var_dict:
                return str(var_dict[name])
            return match.group(0)

        for key, value in json_obj.items():
            if key == "vars":
                continue
            if isinstance(value, dict):
                self._replace_vars_with_values(var_dict, value)
            elif isinstance(value, str):
                # Re-assigning an existing key is safe while iterating.
                json_obj[key] = self._VAR_MARKER.sub(_expand, value)

    def _resolve_model_config(
        self,
        model_name: str,
        data_dir: str,
        visited: set = None,
        max_depth: int = 10,
    ) -> Union[dict, None]:
        """Resolve a model entry to its final configuration dictionary.

        A string entry is an alias for another entry (an optional leading
        ``@`` is stripped). Alias chains are followed until a dictionary entry
        is reached, with explicit cycle detection and a depth limit.

        :param model_name: Model name to resolve (case-insensitive).
        :param data_dir: Reserved for compatibility with existing call sites.
        :param visited: Model names already visited before this call.
        :param max_depth: Maximum alias-chain depth.
        :returns: Resolved model configuration dictionary, or ``None`` when a
            name in the chain is not present or the final entry is neither an
            alias nor a dictionary.
        :raises InvalidConfigFile: If the alias chain is circular or exceeds
            ``max_depth``.
        """
        # Keep an ordered chain for readable error messages and a set for
        # O(1) cycle checks.
        chain = list(visited) if visited is not None else []
        seen = set(chain)
        current = model_name

        while True:
            if len(chain) >= max_depth:
                raise InvalidConfigFile(
                    f"Maximum alias resolution depth ({max_depth}) exceeded. "
                    f"Possible circular alias in model manifest {self.model_manifest}. "
                    f"Resolution chain: {' -> '.join(chain)} -> {current}"
                )

            current = current.lower()
            if current in seen:
                # Without this check a short cycle never grows the visited
                # collection, so the depth limit alone would never fire.
                raise InvalidConfigFile(
                    f"Circular alias detected in model manifest {self.model_manifest}. "
                    f"Resolution chain: {' -> '.join(chain)} -> {current}"
                )
            if current not in self.models:
                return None

            entry = self.models[current]
            chain.append(current)
            seen.add(current)

            if isinstance(entry, str):
                # Alias reference: strip the optional '@' prefix and follow it.
                current = entry[1:] if entry.startswith("@") else entry
                continue

            return entry if isinstance(entry, dict) else None

    def get_model(
        self,
        model_name: str = "default",
        data_dir: str = ".",
        reference_frame: "Union[ReferenceFrame, None]" = None,
    ) -> "Union[PlateModel, None]":
        """Return a :class:`PlateModel` for ``model_name``.

        The method resolves aliases, applies optional reference-frame handling,
        and instantiates :class:`PlateModel` with the resolved configuration.

        :param model_name: Model name or alias (case-insensitive). Defaults to
            ``"default"``.
        :param data_dir: Parent directory for model downloads and cache files.
        :param reference_frame: Optional reference frame. When PMAG is requested
            and a ``_pmag_ref`` variant exists, that variant is selected
            automatically.
        :returns: A configured :class:`PlateModel`, or ``None`` if the model is
            unavailable or incompatible with the requested reference frame.
        :raises InvalidConfigFile: If alias resolution detects an invalid alias
            chain.
        """
        model_name_lower = model_name.lower()
        wants_pmag = reference_frame == ReferenceFrame.PmagReferenceFrame

        if wants_pmag and f"{model_name_lower}_pmag_ref" in self.models:
            model_name_lower += "_pmag_ref"

        model_cfg = self._resolve_model_config(model_name_lower, data_dir)
        if model_cfg is None:
            logger.error("Model '%s' is not available.", model_name)
            return None

        if wants_pmag and not model_name_lower.endswith("_pmag_ref"):
            # No dedicated PMAG variant was selected, so the model itself must
            # declare an anchor plate ID for the PMAG reference frame.
            attributes = model_cfg.get("Attributes") or {}
            if attributes.get("PmagReferenceFrameAnchorPID") is None:
                logger.error(
                    "Model '%s' does not have a PMAG reference frame version available.",
                    model_name,
                )
                return None

        return PlateModel(
            model_name_lower,
            model_cfg=model_cfg,
            data_dir=data_dir,
            reference_frame=reference_frame,
        )

    def get_available_model_names(self):
        """Return all model keys from the loaded manifest.

        :returns: Available model names and aliases.
        :rtype: list[str]
        """
        return list(self.models.keys())

    @staticmethod
    def get_local_available_model_names(local_dir: str):
        """Return locally available model names from ``local_dir``.

        :param local_dir: The local folder containing models.
        :type local_dir: str
        :returns: Names of subdirectories that look like valid local PMM models
            (contain ``.metadata.json``).
        :rtype: list[str]
        """
        return [
            entry
            for entry in os.listdir(local_dir)
            if os.path.isdir(os.path.join(local_dir, entry))
            and os.path.isfile(os.path.join(local_dir, entry, ".metadata.json"))
        ]

    @staticmethod
    def get_default_repo_url():
        """Return the first reachable default model-manifest URL.

        Endpoints are probed in order using HTTP ``HEAD`` requests.

        :returns: Reachable manifest URL.
        :rtype: str
        :raises ServerUnavailable: If none of the default endpoints are
            reachable.
        """
        default_repo_url_list = [
            "https://repo.gplates.org/webdav/pmm/config/models_v2.json",
            "https://www.earthbyte.org/webdav/pmm/config/models_v2_eb.json",
            "https://portal.gplates.org/static/pmm/config/models_v2_gp.json",
        ]
        for url in default_repo_url_list:
            try:
                response = requests.head(url, timeout=(5, 5))
            except requests.exceptions.RequestException:
                # Best-effort probing: move on to the next mirror, but do not
                # swallow KeyboardInterrupt/SystemExit like a bare except would.
                logger.warning("Unable to fetch %s.", url)
                continue
            if response.status_code == 200:
                return url
            logger.warning(
                "Unable to fetch %s. status_code=%s", url, response.status_code
            )
        raise ServerUnavailable(
            """Cannot connect to the servers. Either the servers are currently unavailable, or there is a problem with your internet connection."""
        )

    def download_all_models(self, data_dir: str = "./") -> None:
        """Download layer data for all available models into ``data_dir``.

        Every key returned by :meth:`get_available_model_names` is processed,
        which includes alias entries.

        :param data_dir: Destination directory for downloaded model data.
        :type data_dir: str
        """
        for name in self.get_available_model_names():
            print(f"download {name}")
            model = self.get_model(name)
            if model is not None:
                model.set_data_dir(data_dir)
                model.download_all_layers()