Skip to content

anypinn.lib.utils

T = TypeVar('T') module-attribute

find(iterable: Iterable[T], predicate: Callable[[T], bool], default: T | None = None) -> T | None

Find the first element in an iterable that satisfies a predicate.

Parameters:

Name Type Description Default
iterable Iterable[T]

The iterable to search.

required
predicate Callable[[T], bool]

A function that returns True for the desired element.

required
default T | None

The value to return if no element is found. Defaults to None.

None

Returns:

Type Description
T | None

The first matching element, or the default value.

Source code in src/anypinn/lib/utils.py
def find(
    iterable: Iterable[T],
    predicate: Callable[[T], bool],
    default: T | None = None,
) -> T | None:
    """
    Return the first item of ``iterable`` for which ``predicate`` is truthy.

    Args:
        iterable: Items to scan. They are consumed lazily, stopping at the
            first match.
        predicate: Callable applied to each item in order; the first item for
            which it returns a truthy value is returned.
        default: Value returned when no item matches. Defaults to None.

    Returns:
        The first matching item, or ``default`` if nothing matched.
    """
    for item in iterable:
        if predicate(item):
            return item
    return default

find_or_raise(iterable: Iterable[T], predicate: Callable[[T], bool], exception: Exception | Callable[[], Exception] | None = None) -> T

Find the first element in an iterable that satisfies a predicate, or raise an exception.

Parameters:

Name Type Description Default
iterable Iterable[T]

The iterable to search.

required
predicate Callable[[T], bool]

A function that returns True for the desired element.

required
exception Exception | Callable[[], Exception] | None

The exception to raise if no element is found. Can be an Exception instance, a callable returning an Exception, or None (raises ValueError).

None

Returns:

Type Description
T

The first matching element.

Raises:

Type Description
ValueError

If no element is found and no specific exception is provided.

Exception

The provided exception (or the exception returned by the provided callable) if no element is found.

Source code in src/anypinn/lib/utils.py
def find_or_raise(
    iterable: Iterable[T],
    predicate: Callable[[T], bool],
    exception: Exception | Callable[[], Exception] | None = None,
) -> T:
    """
    Find the first element in an iterable that satisfies a predicate, or raise an exception.

    Unlike checking the result of ``find`` against ``None``, a matching element
    whose value is ``None`` (or any other falsy value) is returned rather than
    being mistaken for "not found".

    Args:
        iterable: The iterable to search. It is consumed only up to the first match.
        predicate: A function that returns True for the desired element.
        exception: The exception to raise if no element is found.
                   Can be an Exception instance, a callable returning an Exception,
                   or None (raises ValueError).

    Returns:
        The first matching element.

    Raises:
        ValueError: If no element is found and no specific exception is provided.
        Exception: The provided exception, or the exception returned by the
            provided callable, if no element is found.
    """
    for item in iterable:
        if predicate(item):
            # Return on the match itself rather than inspecting its value, so a
            # matching None is still a successful lookup.
            return item

    if exception is None:
        raise ValueError("Element not found")
    if isinstance(exception, Exception):
        raise exception
    # Otherwise the caller supplied a zero-argument factory; build the
    # exception lazily so it is only constructed when actually needed.
    raise exception()

get_tensorboard_logger(trainer: Trainer, default: TensorBoardLogger | None = None) -> TensorBoardLogger | None

Retrieve the TensorBoardLogger from the trainer.

Parameters:

Name Type Description Default
trainer Trainer

The PyTorch Lightning Trainer instance.

required
default TensorBoardLogger | None

Value to return if no TensorBoardLogger is attached to the trainer.

None

Returns:

Type Description
TensorBoardLogger | None

The TensorBoardLogger or the default value.

Source code in src/anypinn/lib/utils.py
def get_tensorboard_logger(
    trainer: Trainer,
    default: TensorBoardLogger | None = None,
) -> TensorBoardLogger | None:
    """
    Return the first TensorBoardLogger among the trainer's loggers.

    Args:
        trainer: The PyTorch Lightning Trainer whose ``loggers`` are searched.
        default: Value returned when none of the loggers is a TensorBoardLogger.

    Returns:
        The first attached TensorBoardLogger, or ``default`` if there is none.
    """
    # The isinstance check narrows the type, so no cast is required here.
    for candidate in trainer.loggers:
        if isinstance(candidate, TensorBoardLogger):
            return candidate
    return default

get_tensorboard_logger_or_raise(trainer: Trainer) -> TensorBoardLogger

Retrieve the TensorBoardLogger from the trainer, or raise if not present.

Parameters:

Name Type Description Default
trainer Trainer

The PyTorch Lightning Trainer instance.

required

Returns:

Type Description
TensorBoardLogger

The TensorBoardLogger.

Raises:

Type Description
ValueError

If no TensorBoardLogger is attached to the trainer.

Source code in src/anypinn/lib/utils.py
def get_tensorboard_logger_or_raise(trainer: Trainer) -> TensorBoardLogger:
    """
    Return the first TensorBoardLogger among the trainer's loggers, failing loudly if absent.

    Args:
        trainer: The PyTorch Lightning Trainer whose ``loggers`` are searched.

    Returns:
        The first attached TensorBoardLogger.

    Raises:
        ValueError: If none of the trainer's loggers is a TensorBoardLogger.
    """
    for candidate in trainer.loggers:
        if isinstance(candidate, TensorBoardLogger):
            return candidate
    raise ValueError("TensorBoard logger not found")