Source code for rshf.utils

import json
import os
import ssl

from huggingface_hub import get_collection, hf_hub_download

# HACK: Ignore SSL errors.
# WARNING: this replaces the interpreter-wide default HTTPS context factory, so
# every client in the process that relies on ``ssl._create_default_https_context``
# (e.g. ``urllib.request`` / ``http.client``) stops verifying server
# certificates as soon as this module is imported -- not just rshf's own calls.
# It stays on by default for backward compatibility; set the environment
# variable ``RSHF_VERIFY_SSL=1`` (or ``true``/``yes``) before importing rshf to
# keep certificate verification enabled.
# NOTE(review): HTTP clients that build their own SSL context (possibly the one
# huggingface_hub uses) may not be affected by this override at all -- confirm
# whether the hack is still needed.
if os.environ.get("RSHF_VERIFY_SSL", "").strip().lower() not in ("1", "true", "yes"):
    ssl._create_default_https_context = ssl._create_unverified_context


# HuggingFace Hub collection searched by ``list_models`` when no other
# collection is given.
_DEFAULT_COLLECTION = "MVRL/remote-sensing-foundation-models-664e8fcd67d8ca8c03f42d00"


def list_models(model_name, collection_slug=_DEFAULT_COLLECTION):
    """Print the IDs of collection items whose ID contains ``model_name``.

    Matching is a case-insensitive substring test against each item's
    ``item_id``. A header line is always printed, even when nothing matches.

    Args:
        model_name (str): Substring to search for (e.g. ``"satmae"``).
        collection_slug (str, optional): Slug of the HuggingFace Hub collection
            to search. Defaults to the collection that was previously
            hard-coded here.
    """
    collection = get_collection(collection_slug)
    print(f"Available {model_name} pretrained models:\n")
    for item in collection.items:
        if model_name.lower() in item.item_id.lower():
            print(item.item_id)


def help(model):
    """Print the docstring of ``model`` to standard output.

    Note: this name shadows the builtin ``help`` in any namespace that imports
    it. Only ``model.__doc__`` is echoed, so an object without a docstring
    prints ``None``.

    Args:
        model: Any object exposing a ``__doc__`` attribute.
    """
    docstring = model.__doc__
    print(docstring)


def from_config(model_class, repo_id, revision=None, **kwargs):
    """Load a model with randomly initialized weights using the architecture
    configuration stored in a HuggingFace Hub repository.

    This is useful for training a model from scratch while still using the
    same architecture as a known pretrained checkpoint.

    Only ``config.json`` is downloaded; no weights are fetched. Top-level keys
    starting with ``"_"`` (HuggingFace Hub metadata) are dropped before the
    remaining keys are passed to ``model_class``.

    Args:
        model_class: The model class to instantiate (e.g. ``SatMAE``).
        repo_id (str): HuggingFace Hub repository ID
            (e.g. ``"MVRL/satmae-vitlarge-fmow-pretrain-800"``).
        revision (str, optional): Branch, tag, or commit hash to use.
            Defaults to the latest revision.
        **kwargs: Additional keyword arguments that override values read from
            the repository's ``config.json``. These must be valid parameters
            for ``model_class.__init__``; unknown parameters will raise an
            error when the model is instantiated.

    Returns:
        An instance of ``model_class`` with randomly initialized weights.

    Raises:
        huggingface_hub.utils.EntryNotFoundError: If ``config.json`` is not
            found in the repository.
        huggingface_hub.utils.RepositoryNotFoundError: If ``repo_id`` does not
            exist or is not accessible.
        ValueError: If ``config.json`` is not valid JSON or does not contain a
            JSON object.

    Example:
        >>> from rshf import from_config
        >>> from rshf.satmae import SatMAE
        >>> model = from_config(SatMAE, "MVRL/satmae-vitlarge-fmow-pretrain-800")
    """
    try:
        config_path = hf_hub_download(
            repo_id=repo_id, filename="config.json", revision=revision
        )
    except Exception as err:
        message = (
            f"Could not download config.json from '{repo_id}'. "
            f"Ensure the repository exists and contains a config.json file. "
            f"Original error: {err}"
        )
        try:
            # Rebuild the same exception type so callers catching e.g.
            # EntryNotFoundError keep working, but with a clearer message.
            wrapped = type(err)(message)
        except Exception:
            # Some exception types cannot be built from a single message
            # argument (extra required args). Re-raise the original error
            # instead of masking it with an unrelated constructor failure.
            wrapped = None
        if wrapped is None:
            raise  # re-raises ``err`` with its original traceback
        raise wrapped from err

    # JSON is specified as UTF-8 (RFC 8259); don't depend on the locale's
    # default encoding, which differs across platforms.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    if not isinstance(config, dict):
        raise ValueError(
            f"config.json in '{repo_id}' must contain a JSON object, "
            f"got {type(config).__name__}"
        )

    # Remove internal HuggingFace Hub metadata keys (prefixed with "_").
    config = {k: v for k, v in config.items() if not k.startswith("_")}
    config.update(kwargs)
    return model_class(**config)