master
/miniconda3/envs/poem/lib/python3.10/site-packages/triton/impl/base.py

base.py @a8e0244 raw · history · blame

from __future__ import annotations

from functools import wraps
from typing import TypeVar

# Generic type variable used by the decorators below so their signatures
# read ``fn: T -> T`` (the decorated callable's static type is preserved).
T = TypeVar("T")

# Name of the marker attribute that ``builtin`` sets to True on its wrappers
# and that ``is_builtin`` reads back.
TRITON_BUILTIN = "__triton_builtin__"


def builtin(fn: T) -> T:
    """Mark a function as a builtin.

    Wraps *fn* so that calling it without a non-``None`` ``_builder``
    keyword argument fails loudly. The error message suggests that the
    ``@triton.jit`` machinery normally supplies ``_builder``; that is an
    inference from the message, not something this module shows. The
    wrapper also carries the ``TRITON_BUILTIN`` marker attribute, which
    ``is_builtin`` checks.

    Args:
        fn: The callable to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper around *fn*. When called, it
        forwards all positional and keyword arguments to *fn* unchanged.
        It raises ``ValueError`` if ``_builder`` is missing or ``None``.

    Raises:
        TypeError: If *fn* is not callable. This is an explicit check rather
            than an ``assert``, so the validation still runs under
            ``python -O``.
    """
    if not callable(fn):
        raise TypeError(f"builtin() expects a callable, got {type(fn).__name__}")

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # A missing ``_builder`` key and an explicit ``_builder=None`` are
        # both rejected.
        if kwargs.get("_builder") is None:
            raise ValueError(
                "Did you forget to add @triton.jit ? "
                "(`_builder` argument must be provided outside of JIT functions.)"
            )
        return fn(*args, **kwargs)

    setattr(wrapper, TRITON_BUILTIN, True)

    return wrapper


def is_builtin(fn) -> bool:
    """Report whether *fn* carries the triton builtin marker attribute.

    Returns the value stored under ``TRITON_BUILTIN`` on *fn*. If the
    attribute lookup raises ``AttributeError``, returns ``False``.
    """
    try:
        return getattr(fn, TRITON_BUILTIN)
    except AttributeError:
        # Not decorated with ``builtin``: no marker attribute present.
        return False


def extern(fn: T) -> T:
    """Decorate an external function.

    External functions get exactly the same treatment as builtins: the
    result is whatever ``builtin(fn)`` returns.
    """
    decorated = builtin(fn)
    return decorated