Skip to content

Primitives

Geometric primitives inheriting from equinox.Module with SDF, parameters, and volume interface.

brepax.primitives

Geometric primitives with SDF interface.

BSplineSurface

Bases: Primitive

B-spline surface defined by a control point grid and knot vectors.

The SDF is computed via closest-point projection onto the surface. Control points are differentiable design variables: jax.grad flows through the unrolled Newton projection.

Attributes:

Name Type Description
control_points Float[Array, 'nu nv 3']

Control point grid, shape (n_u, n_v, 3).

knots_u Array

Knot vector in u-direction.

knots_v Array

Knot vector in v-direction.

degree_u int

Polynomial degree in u.

degree_v int

Polynomial degree in v.

Examples:

>>> import jax.numpy as jnp
>>> pts = jnp.array([[[0,0,0],[1,0,0]],
...                  [[0,1,0],[1,1,0]]], dtype=float)
>>> knots = jnp.array([0., 0., 1., 1.])
>>> surf = BSplineSurface(
...     control_points=pts, knots_u=knots, knots_v=knots,
...     degree_u=1, degree_v=1,
... )
>>> d = surf.sdf(jnp.array([0.5, 0.5, 1.0]))
Source code in src/brepax/primitives/bspline_surface.py
class BSplineSurface(Primitive):
    """B-spline surface defined by a control point grid and knot vectors.

    The SDF is computed via closest-point projection onto the surface.
    Control points are differentiable design variables: ``jax.grad``
    flows through the unrolled Newton projection.

    Attributes:
        control_points: Control point grid, shape ``(n_u, n_v, 3)``.
        knots_u: Knot vector in u-direction.
        knots_v: Knot vector in v-direction.
        degree_u: Polynomial degree in u (static field).
        degree_v: Polynomial degree in v (static field).
        weights: Optional array forwarded to ``evaluate_surface`` and
            ``bspline_sdf``; presumably per-control-point rational
            weights -- confirm against those helpers.
        sign_flip: Static float forwarded to ``bspline_sdf``.
        param_u_range: Optional static ``(lo, hi)`` pair forwarded to
            ``bspline_sdf``.
        param_v_range: Optional static ``(lo, hi)`` pair forwarded to
            ``bspline_sdf``.
        coarse_positions: Optional array forwarded to ``bspline_sdf``.
        coarse_normals: Optional array forwarded to ``bspline_sdf``.
        trim_polygon: Optional array; not read by any method of this class.
        trim_mask: Optional array; not read by any method of this class.

    Examples:
        >>> import jax.numpy as jnp
        >>> pts = jnp.array([[[0,0,0],[1,0,0]],
        ...                  [[0,1,0],[1,1,0]]], dtype=float)
        >>> knots = jnp.array([0., 0., 1., 1.])
        >>> surf = BSplineSurface(
        ...     control_points=pts, knots_u=knots, knots_v=knots,
        ...     degree_u=1, degree_v=1,
        ... )
        >>> d = surf.sdf(jnp.array([0.5, 0.5, 1.0]))
    """

    control_points: Float[Array, "nu nv 3"]
    knots_u: Array = eqx.field()
    knots_v: Array = eqx.field()
    degree_u: int = eqx.field(static=True)
    degree_v: int = eqx.field(static=True)
    weights: Array | None = eqx.field(default=None)
    sign_flip: float = eqx.field(default=1.0, static=True)
    param_u_range: tuple[float, float] | None = eqx.field(default=None, static=True)
    param_v_range: tuple[float, float] | None = eqx.field(default=None, static=True)
    coarse_positions: Array | None = eqx.field(default=None)
    coarse_normals: Array | None = eqx.field(default=None)
    trim_polygon: Array | None = eqx.field(default=None)
    trim_mask: Array | None = eqx.field(default=None)

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the B-spline surface.

        A coarse grid of surface samples is evaluated once per call. Every
        query point then picks its nearest sample as the starting ``(u, v)``
        guess and is projected independently by ``bspline_sdf``.

        Args:
            x: Query points; the trailing axis holds the 3 coordinates.

        Returns:
            One signed distance per query point, shaped like ``x`` without
            its trailing axis.
        """
        queries = x.reshape(-1, 3)

        # Function-scope import, kept exactly where the original has it.
        from brepax.nurbs.projection import _COARSE_GRID

        pu, pv = self.degree_u, self.degree_v
        # Seed parameters span knots[p] .. knots[-p - 1] in each direction.
        axis_u = jnp.linspace(self.knots_u[pu], self.knots_u[-pu - 1], _COARSE_GRID)
        axis_v = jnp.linspace(self.knots_v[pv], self.knots_v[-pv - 1], _COARSE_GRID)
        mesh_u, mesh_v = jnp.meshgrid(axis_u, axis_v, indexing="ij")
        seed_u = mesh_u.ravel()
        seed_v = mesh_v.ravel()

        def surface_point(u: Array, v: Array) -> Array:
            return evaluate_surface(
                self.control_points,
                self.knots_u,
                self.knots_v,
                pu,
                pv,
                u,
                v,
                self.weights,
            )

        # Shared by every query point: evaluated once, outside the vmap below.
        seed_points = jax.vmap(surface_point)(seed_u, seed_v)

        def project_one(query: Array) -> Array:
            sq_dist = jnp.sum((seed_points - query) ** 2, axis=-1)
            nearest = jnp.argmin(sq_dist)
            # The seed is only an initial guess, so it is detached from the
            # gradient graph (argmin itself is not differentiable).
            start_u = jax.lax.stop_gradient(seed_u[nearest])
            start_v = jax.lax.stop_gradient(seed_v[nearest])
            return bspline_sdf(
                query,
                self.control_points,
                self.knots_u,
                self.knots_v,
                pu,
                pv,
                u0=start_u,
                v0=start_v,
                weights=self.weights,
                param_u_range=self.param_u_range,
                param_v_range=self.param_v_range,
                sign_flip=self.sign_flip,
                coarse_positions=self.coarse_positions,
                coarse_normals=self.coarse_normals,
            )

        distances = jax.vmap(project_one)(queries)
        return distances.reshape(x.shape[:-1])

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters.

        Returns:
            A dict with ``control_points``, ``knots_u`` and ``knots_v``,
            plus ``weights`` when weights are set.
        """
        always = {
            "control_points": self.control_points,
            "knots_u": self.knots_u,
            "knots_v": self.knots_v,
        }
        if self.weights is None:
            return always
        return {**always, "weights": self.weights}

parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/bspline_surface.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict with ``control_points``, ``knots_u`` and ``knots_v``,
        plus ``weights`` when weights are set.
    """
    always = {
        "control_points": self.control_points,
        "knots_u": self.knots_u,
        "knots_v": self.knots_v,
    }
    if self.weights is None:
        return always
    return {**always, "weights": self.weights}

sdf(x)

Signed distance from query points to the B-spline surface.

For batched inputs, each point is projected independently via Newton iteration.

Source code in src/brepax/primitives/bspline_surface.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the B-spline surface.

    A coarse grid of surface samples is evaluated once per call. Every
    query point then picks its nearest sample as the starting ``(u, v)``
    guess and is projected independently by ``bspline_sdf``.

    Args:
        x: Query points; the trailing axis holds the 3 coordinates.

    Returns:
        One signed distance per query point, shaped like ``x`` without
        its trailing axis.
    """
    queries = x.reshape(-1, 3)

    # Function-scope import, kept exactly where the original has it.
    from brepax.nurbs.projection import _COARSE_GRID

    pu, pv = self.degree_u, self.degree_v
    # Seed parameters span knots[p] .. knots[-p - 1] in each direction.
    axis_u = jnp.linspace(self.knots_u[pu], self.knots_u[-pu - 1], _COARSE_GRID)
    axis_v = jnp.linspace(self.knots_v[pv], self.knots_v[-pv - 1], _COARSE_GRID)
    mesh_u, mesh_v = jnp.meshgrid(axis_u, axis_v, indexing="ij")
    seed_u = mesh_u.ravel()
    seed_v = mesh_v.ravel()

    def surface_point(u: Array, v: Array) -> Array:
        return evaluate_surface(
            self.control_points,
            self.knots_u,
            self.knots_v,
            pu,
            pv,
            u,
            v,
            self.weights,
        )

    # Shared by every query point: evaluated once, outside the vmap below.
    seed_points = jax.vmap(surface_point)(seed_u, seed_v)

    def project_one(query: Array) -> Array:
        sq_dist = jnp.sum((seed_points - query) ** 2, axis=-1)
        nearest = jnp.argmin(sq_dist)
        # The seed is only an initial guess, so it is detached from the
        # gradient graph (argmin itself is not differentiable).
        start_u = jax.lax.stop_gradient(seed_u[nearest])
        start_v = jax.lax.stop_gradient(seed_v[nearest])
        return bspline_sdf(
            query,
            self.control_points,
            self.knots_u,
            self.knots_v,
            pu,
            pv,
            u0=start_u,
            v0=start_v,
            weights=self.weights,
            param_u_range=self.param_u_range,
            param_v_range=self.param_v_range,
            sign_flip=self.sign_flip,
            coarse_positions=self.coarse_positions,
            coarse_normals=self.coarse_normals,
        )

    distances = jax.vmap(project_one)(queries)
    return distances.reshape(x.shape[:-1])

Box

Bases: Primitive

An axis-aligned box defined by center and half-extents.

Attributes:

Name Type Description
center Float[Array, 3]

Center of the box (3,).

half_extents Float[Array, 3]

Half-size in each dimension (3,).

Source code in src/brepax/primitives/box.py
class Box(Primitive):
    """An axis-aligned box defined by center and half-extents.

    Attributes:
        center: Center of the box (3,).
        half_extents: Half-size in each dimension (3,).
    """

    center: Float[Array, "3"]
    half_extents: Float[Array, "3"]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the box surface.

        Args:
            x: Query points; the trailing axis holds the 3 coordinates.

        Returns:
            Signed distance per point: positive outside, negative inside.
        """
        # Per-axis signed gap to the faces (negative while within the slab).
        gap = jnp.abs(x - self.center) - self.half_extents
        largest_gap = jnp.max(gap, axis=-1)
        overshoot = jnp.maximum(gap, 0.0)
        # The eps keeps the square root differentiable where the overshoot
        # is all zeros; `where` then removes it from the value for points
        # that are not outside.
        exterior = jnp.where(
            largest_gap > 0,
            jnp.sqrt(jnp.sum(overshoot**2, axis=-1) + 1e-20),
            0.0,
        )
        interior = jnp.minimum(largest_gap, 0.0)
        return exterior + interior

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return dict(center=self.center, half_extents=self.half_extents)

    def volume(self) -> Float[Array, ""]:
        """Box volume: 8 * hx * hy * hz."""
        half_product = jnp.prod(self.half_extents)
        return 8.0 * half_product

parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/box.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict holding ``center`` and ``half_extents``.
    """
    return dict(center=self.center, half_extents=self.half_extents)

sdf(x)

Signed distance from query points to the box surface.

Source code in src/brepax/primitives/box.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the box surface.

    Args:
        x: Query points; the trailing axis holds the 3 coordinates.

    Returns:
        Signed distance per point: positive outside, negative inside.
    """
    # Per-axis signed gap to the faces (negative while within the slab).
    gap = jnp.abs(x - self.center) - self.half_extents
    largest_gap = jnp.max(gap, axis=-1)
    overshoot = jnp.maximum(gap, 0.0)
    # The eps keeps the square root differentiable where the overshoot is
    # all zeros; `where` then removes it from the value for points that
    # are not outside.
    exterior = jnp.where(
        largest_gap > 0,
        jnp.sqrt(jnp.sum(overshoot**2, axis=-1) + 1e-20),
        0.0,
    )
    interior = jnp.minimum(largest_gap, 0.0)
    return exterior + interior

volume()

Box volume: 8 * hx * hy * hz.

Source code in src/brepax/primitives/box.py
def volume(self) -> Float[Array, ""]:
    """Box volume: 8 * hx * hy * hz."""
    half_product = jnp.prod(self.half_extents)
    return 8.0 * half_product

Cone

Bases: Primitive

An infinite cone defined by apex position, axis direction, and half-angle.

The cone extends infinitely from the apex along the axis direction. The SDF is positive outside, negative inside.

Attributes:

Name Type Description
apex Float[Array, 3]

Apex point of the cone (3,).

axis Float[Array, 3]

Unit direction vector from apex (3,). Must be normalized.

angle Float[Array, '']

Half-angle in radians, strictly between 0 and pi/2.

Source code in src/brepax/primitives/cone.py
class Cone(Primitive):
    """An infinite cone defined by apex position, axis direction, and half-angle.

    The cone extends infinitely from the apex along the axis direction.
    The SDF is positive outside, negative inside.

    Attributes:
        apex: Apex point of the cone (3,).
        axis: Unit direction vector from apex (3,). Must be normalized.
        angle: Half-angle in radians (0, pi/2).
    """

    apex: Float[Array, "3"]
    axis: Float[Array, "3"]
    angle: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the cone surface.

        Args:
            x: Query points; the trailing axis holds the 3 coordinates.

        Returns:
            Signed distance per point: positive outside, negative inside.
        """
        v = x - self.apex
        # Distance along axis
        h = jnp.sum(v * self.axis, axis=-1)
        # Perpendicular distance from axis (clamped so sqrt stays differentiable)
        perp = jnp.sqrt(jnp.maximum(jnp.sum(v * v, axis=-1) - h**2, 1e-20))
        sin_a = jnp.sin(self.angle)
        cos_a = jnp.cos(self.angle)
        # Coordinate of the query along the cone generator, measured from
        # the apex. Only when it is negative is the apex the nearest
        # surface point. Switching on `h >= 0` instead makes the field jump
        # across the plane h == 0 and over-estimates the distance for
        # points below that plane but far from the axis.
        along_generator = perp * sin_a + h * cos_a
        return jnp.where(
            along_generator >= 0,
            # Signed distance to the lateral surface.
            perp * cos_a - h * sin_a,
            # Distance to the apex.
            jnp.sqrt(perp**2 + h**2),
        )

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return {"apex": self.apex, "axis": self.axis, "angle": self.angle}

parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/cone.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict holding ``apex``, ``axis`` and ``angle``.
    """
    return dict(apex=self.apex, axis=self.axis, angle=self.angle)

sdf(x)

Signed distance from query points to the cone surface.

Source code in src/brepax/primitives/cone.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the cone surface.

    Args:
        x: Query points; the trailing axis holds the 3 coordinates.

    Returns:
        Signed distance per point: positive outside, negative inside.
    """
    v = x - self.apex
    # Distance along axis
    h = jnp.sum(v * self.axis, axis=-1)
    # Perpendicular distance from axis (clamped so sqrt stays differentiable)
    perp = jnp.sqrt(jnp.maximum(jnp.sum(v * v, axis=-1) - h**2, 1e-20))
    sin_a = jnp.sin(self.angle)
    cos_a = jnp.cos(self.angle)
    # Coordinate of the query along the cone generator, measured from the
    # apex. Only when it is negative is the apex the nearest surface point.
    # Switching on `h >= 0` instead makes the field jump across the plane
    # h == 0 and over-estimates the distance for points below that plane
    # but far from the axis.
    along_generator = perp * sin_a + h * cos_a
    return jnp.where(
        along_generator >= 0,
        # Signed distance to the lateral surface.
        perp * cos_a - h * sin_a,
        # Distance to the apex.
        jnp.sqrt(perp**2 + h**2),
    )

Cylinder

Bases: Primitive

An infinite cylinder defined by a point on the axis, axis direction, and radius.

The SDF measures perpendicular distance from the axis line minus the radius. The cylinder extends infinitely along the axis direction.

Attributes:

Name Type Description
point Float[Array, 3]

A point on the cylinder axis (3,).

axis Float[Array, 3]

Unit direction vector of the axis (3,). Must be normalized.

radius Float[Array, '']

Scalar radius (must be positive).

Source code in src/brepax/primitives/cylinder.py
class Cylinder(Primitive):
    """An infinite cylinder defined by a point on the axis, axis direction, and radius.

    The SDF measures perpendicular distance from the axis line minus the radius.
    The cylinder extends infinitely along the axis direction.

    Attributes:
        point: A point on the cylinder axis (3,).
        axis: Unit direction vector of the axis (3,). Must be normalized.
        radius: Scalar radius (must be positive).
    """

    point: Float[Array, "3"]
    axis: Float[Array, "3"]
    radius: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the cylinder surface.

        Computes the perpendicular distance from each query point to
        the cylinder axis, then subtracts the radius.

        Args:
            x: Query points; the trailing axis holds the 3 coordinates.

        Returns:
            Signed distance per point: positive outside, negative inside.
        """
        # Vector from axis point to query
        v = x - self.point
        # Project onto axis to get parallel component
        proj_len = jnp.sum(v * self.axis, axis=-1, keepdims=True)
        # Perpendicular component
        perp = v - proj_len * self.axis
        # Clamp the squared length before the square root: a query exactly
        # on the axis would otherwise give sqrt(0), whose derivative is
        # infinite and turns the gradient into NaN.
        perp_sq = jnp.sum(perp * perp, axis=-1)
        perp_dist = jnp.sqrt(jnp.maximum(perp_sq, 1e-20))
        return perp_dist - self.radius

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return {"point": self.point, "axis": self.axis, "radius": self.radius}

parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/cylinder.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict holding ``point``, ``axis`` and ``radius``.
    """
    return dict(point=self.point, axis=self.axis, radius=self.radius)

sdf(x)

Signed distance from query points to the cylinder surface.

Computes the perpendicular distance from each query point to the cylinder axis, then subtracts the radius.

Source code in src/brepax/primitives/cylinder.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the cylinder surface.

    Computes the perpendicular distance from each query point to
    the cylinder axis, then subtracts the radius.

    Args:
        x: Query points; the trailing axis holds the 3 coordinates.

    Returns:
        Signed distance per point: positive outside, negative inside.
    """
    # Vector from axis point to query
    v = x - self.point
    # Project onto axis to get parallel component
    proj_len = jnp.sum(v * self.axis, axis=-1, keepdims=True)
    # Perpendicular component
    perp = v - proj_len * self.axis
    # Clamp the squared length before the square root: a query exactly on
    # the axis would otherwise give sqrt(0), whose derivative is infinite
    # and turns the gradient into NaN.
    perp_sq = jnp.sum(perp * perp, axis=-1)
    perp_dist = jnp.sqrt(jnp.maximum(perp_sq, 1e-20))
    return perp_dist - self.radius

Disk

Bases: Primitive

A 2D disk defined by center and radius.

Attributes:

Name Type Description
center Float[Array, 2]

Center coordinates (2,).

radius Float[Array, '']

Scalar radius (must be positive).

Source code in src/brepax/primitives/disk.py
class Disk(Primitive):
    """A 2D disk defined by center and radius.

    Attributes:
        center: Center coordinates (2,).
        radius: Scalar radius (must be positive).
    """

    center: Float[Array, "2"]
    radius: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 2"]) -> Float[Array, "..."]:
        """Signed distance from query points to the disk boundary.

        Args:
            x: Query points; the trailing axis holds the 2 coordinates.

        Returns:
            Signed distance per point: positive outside, negative inside.
        """
        offset = x - self.center
        # Clamp before the square root: at the exact center the plain norm
        # is sqrt(0), whose derivative is infinite and turns the gradient
        # into NaN.
        dist_sq = jnp.sum(offset * offset, axis=-1)
        return jnp.sqrt(jnp.maximum(dist_sq, 1e-20)) - self.radius

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return {"center": self.center, "radius": self.radius}

    def volume(self) -> Float[Array, ""]:
        """Disk area: pi * r^2."""
        return jnp.pi * self.radius**2

parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/disk.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict holding ``center`` and ``radius``.
    """
    return dict(center=self.center, radius=self.radius)

sdf(x)

Signed distance from query points to the disk boundary.

Source code in src/brepax/primitives/disk.py
def sdf(self, x: Float[Array, "... 2"]) -> Float[Array, "..."]:
    """Signed distance from query points to the disk boundary.

    Args:
        x: Query points; the trailing axis holds the 2 coordinates.

    Returns:
        Signed distance per point: positive outside, negative inside.
    """
    offset = x - self.center
    # Clamp before the square root: at the exact center the plain norm is
    # sqrt(0), whose derivative is infinite and turns the gradient into NaN.
    dist_sq = jnp.sum(offset * offset, axis=-1)
    return jnp.sqrt(jnp.maximum(dist_sq, 1e-20)) - self.radius

volume()

Disk area: pi * r^2.

Source code in src/brepax/primitives/disk.py
def volume(self) -> Float[Array, ""]:
    """Disk area: pi * r^2."""
    radius_sq = self.radius**2
    return jnp.pi * radius_sq

FiniteCylinder

Bases: Primitive

A finite cylinder (capped) defined by center, axis, radius, and height.

The cylinder is centered at center and extends height/2 in each direction along axis.

Attributes:

Name Type Description
center Float[Array, 3]

Center of the cylinder (3,).

axis Float[Array, 3]

Unit direction vector of the axis (3,). Must be normalized.

radius Float[Array, '']

Cylinder radius (must be positive).

height Float[Array, '']

Total height along the axis (must be positive).

Source code in src/brepax/primitives/finite_cylinder.py
class FiniteCylinder(Primitive):
    """A finite cylinder (capped) defined by center, axis, radius, and height.

    The cylinder is centered at `center` and extends `height/2` in each
    direction along `axis`.

    Attributes:
        center: Center of the cylinder (3,).
        axis: Unit direction vector of the axis (3,). Must be normalized.
        radius: Cylinder radius (must be positive).
        height: Total height along the axis (must be positive).
    """

    center: Float[Array, "3"]
    axis: Float[Array, "3"]
    radius: Float[Array, ""]
    height: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the finite cylinder surface.

        Args:
            x: Query points; the trailing axis holds the 3 coordinates.

        Returns:
            Signed distance per point: positive outside, negative inside.
        """
        rel = x - self.center
        # Signed coordinate along the axis, measured from the center.
        axial = jnp.sum(rel * self.axis, axis=-1)
        # Distance from the axis; the clamp keeps sqrt differentiable on it.
        radial = jnp.sqrt(jnp.maximum(jnp.sum(rel * rel, axis=-1) - axial**2, 1e-20))

        # Signed gaps to the side wall and to the end caps.
        gap_radial = radial - self.radius
        gap_axial = jnp.abs(axial) - self.height / 2.0

        # Intersection of an infinite cylinder and a slab.
        excess_radial = jnp.maximum(gap_radial, 0.0)
        excess_axial = jnp.maximum(gap_axial, 0.0)
        is_beyond = jnp.logical_or(excess_radial > 0, excess_axial > 0)
        # The eps avoids a NaN gradient when both excesses are zero; `where`
        # then drops that eps from the value for points that are not outside.
        exterior = jnp.where(
            is_beyond,
            jnp.sqrt(excess_radial**2 + excess_axial**2 + 1e-20),
            0.0,
        )
        interior = jnp.minimum(jnp.maximum(gap_radial, gap_axial), 0.0)
        return exterior + interior

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return dict(
            center=self.center,
            axis=self.axis,
            radius=self.radius,
            height=self.height,
        )

    def volume(self) -> Float[Array, ""]:
        """Finite cylinder volume: pi * r^2 * h."""
        cap_area = jnp.pi * self.radius**2
        return cap_area * self.height

parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/finite_cylinder.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict holding ``center``, ``axis``, ``radius`` and ``height``.
    """
    return dict(
        center=self.center,
        axis=self.axis,
        radius=self.radius,
        height=self.height,
    )

sdf(x)

Signed distance from query points to the finite cylinder surface.

Source code in src/brepax/primitives/finite_cylinder.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the finite cylinder surface.

    Args:
        x: Query points; the trailing axis holds the 3 coordinates.

    Returns:
        Signed distance per point: positive outside, negative inside.
    """
    rel = x - self.center
    # Signed coordinate along the axis, measured from the center.
    axial = jnp.sum(rel * self.axis, axis=-1)
    # Distance from the axis; the clamp keeps sqrt differentiable on it.
    radial = jnp.sqrt(jnp.maximum(jnp.sum(rel * rel, axis=-1) - axial**2, 1e-20))

    # Signed gaps to the side wall and to the end caps.
    gap_radial = radial - self.radius
    gap_axial = jnp.abs(axial) - self.height / 2.0

    # Intersection of an infinite cylinder and a slab.
    excess_radial = jnp.maximum(gap_radial, 0.0)
    excess_axial = jnp.maximum(gap_axial, 0.0)
    is_beyond = jnp.logical_or(excess_radial > 0, excess_axial > 0)
    # The eps avoids a NaN gradient when both excesses are zero; `where`
    # then drops that eps from the value for points that are not outside.
    exterior = jnp.where(
        is_beyond,
        jnp.sqrt(excess_radial**2 + excess_axial**2 + 1e-20),
        0.0,
    )
    interior = jnp.minimum(jnp.maximum(gap_radial, gap_axial), 0.0)
    return exterior + interior

volume()

Finite cylinder volume: pi * r^2 * h.

Source code in src/brepax/primitives/finite_cylinder.py
def volume(self) -> Float[Array, ""]:
    """Finite cylinder volume: pi * r^2 * h."""
    cap_area = jnp.pi * self.radius**2
    return cap_area * self.height

Plane

Bases: Primitive

An infinite plane (half-space boundary) defined by normal and offset.

The SDF is positive on the side the normal points toward (outside), and negative on the opposite side (inside).

Attributes:

Name Type Description
normal Float[Array, 3]

Unit normal vector (3,). Must be normalized.

offset Float[Array, '']

Signed distance from origin to the plane along the normal.

Source code in src/brepax/primitives/plane.py
class Plane(Primitive):
    """An infinite plane (half-space boundary) defined by normal and offset.

    The SDF is positive on the side the normal points toward (outside),
    and negative on the opposite side (inside).

    Attributes:
        normal: Unit normal vector (3,). Must be normalized.
        offset: Signed distance from origin to the plane along the normal.
    """

    normal: Float[Array, "3"]
    offset: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the plane.

        Args:
            x: Query points; the trailing axis holds the 3 coordinates.

        Returns:
            Projection of each point onto the normal, minus the offset.
        """
        along_normal = jnp.sum(x * self.normal, axis=-1)
        return along_normal - self.offset

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return dict(normal=self.normal, offset=self.offset)

parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/plane.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict holding ``normal`` and ``offset``.
    """
    return dict(normal=self.normal, offset=self.offset)

sdf(x)

Signed distance from query points to the plane.

Source code in src/brepax/primitives/plane.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the plane.

    Args:
        x: Query points; the trailing axis holds the 3 coordinates.

    Returns:
        Projection of each point onto the normal, minus the offset.
    """
    along_normal = jnp.sum(x * self.normal, axis=-1)
    return along_normal - self.offset

Primitive

Bases: Module

Base class for all geometric primitives.

Each primitive exposes a signed distance function (SDF), a dictionary of differentiable design parameters that gradients can flow into, and an analytical volume.

Source code in src/brepax/primitives/_base.py
class Primitive(eqx.Module):
    """Base class for all geometric primitives.

    Subclasses must implement ``sdf`` and ``parameters``. ``volume`` has a
    default that returns infinity and is meant to be overridden where an
    analytical formula is known.
    """

    @abstractmethod
    def sdf(self, x: Float[Array, "... dim"]) -> Float[Array, "..."]:
        """Evaluate the signed distance function at query points.

        Args:
            x: Query points; the trailing axis holds the coordinates.

        Returns:
            One signed distance per query point.
        """
        raise NotImplementedError()

    @abstractmethod
    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters.

        Returns:
            A dict mapping parameter names to arrays.
        """
        raise NotImplementedError()

    def volume(self) -> Float[Array, ""]:
        """Analytical volume of this primitive.

        Returns the finite volume for bounded primitives (Sphere, Box, etc.).
        Unbounded primitives (Cylinder, Plane, Cone) return inf.
        Override in subclasses with known analytical formulas.
        Differentiable via jax.grad for gradient computation.
        """
        import jax.numpy as jnp

        # Fallback used by every subclass that does not override volume().
        unbounded = jnp.array(jnp.inf)
        return unbounded

parameters() abstractmethod

Return differentiable design parameters.

Source code in src/brepax/primitives/_base.py
@abstractmethod
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict mapping parameter names to arrays.

    Raises:
        NotImplementedError: Always; subclasses provide the implementation.
    """
    raise NotImplementedError()

sdf(x) abstractmethod

Evaluate the signed distance function at query points.

Source code in src/brepax/primitives/_base.py
@abstractmethod
def sdf(self, x: Float[Array, "... dim"]) -> Float[Array, "..."]:
    """Evaluate the signed distance function at query points.

    Args:
        x: Query points; the trailing axis holds the coordinates.

    Returns:
        One signed distance per query point.

    Raises:
        NotImplementedError: Always; subclasses provide the implementation.
    """
    raise NotImplementedError()

volume()

Analytical volume of this primitive.

Returns the finite volume for bounded primitives (Sphere, Box, etc.). Unbounded primitives (Cylinder, Plane, Cone) return inf. Override in subclasses with known analytical formulas. Differentiable via jax.grad for gradient computation.

Source code in src/brepax/primitives/_base.py
def volume(self) -> Float[Array, ""]:
    """Analytical volume of this primitive.

    Returns the finite volume for bounded primitives (Sphere, Box, etc.).
    Unbounded primitives (Cylinder, Plane, Cone) return inf.
    Override in subclasses with known analytical formulas.
    Differentiable via jax.grad for gradient computation.
    """
    import jax.numpy as jnp

    # Fallback used by every subclass that does not override volume().
    unbounded = jnp.array(jnp.inf)
    return unbounded

Sphere

Bases: Primitive

A 3D sphere defined by center and radius.

Attributes:

Name Type Description
center Float[Array, 3]

Center coordinates (3,).

radius Float[Array, '']

Scalar radius (must be positive).

Source code in src/brepax/primitives/sphere.py
class Sphere(Primitive):
    """A 3D sphere defined by center and radius.

    Attributes:
        center: Center coordinates (3,).
        radius: Scalar radius (must be positive).
    """

    center: Float[Array, "3"]
    radius: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the sphere surface.

        Args:
            x: Query points; the trailing axis holds the 3 coordinates.

        Returns:
            Signed distance per point: positive outside, negative inside.
        """
        offset = x - self.center
        # Clamp before the square root: at the exact center the plain norm
        # is sqrt(0), whose derivative is infinite and turns the gradient
        # into NaN.
        dist_sq = jnp.sum(offset * offset, axis=-1)
        return jnp.sqrt(jnp.maximum(dist_sq, 1e-20)) - self.radius

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return {"center": self.center, "radius": self.radius}

    def volume(self) -> Float[Array, ""]:
        """Sphere volume: 4/3 * pi * r^3."""
        return (4.0 / 3.0) * jnp.pi * self.radius**3

parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/sphere.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict holding ``center`` and ``radius``.
    """
    return dict(center=self.center, radius=self.radius)

sdf(x)

Signed distance from query points to the sphere surface.

Source code in src/brepax/primitives/sphere.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the sphere surface.

    Args:
        x: Query points; the trailing axis holds the 3 coordinates.

    Returns:
        Signed distance per point: positive outside, negative inside.
    """
    offset = x - self.center
    # Clamp before the square root: at the exact center the plain norm is
    # sqrt(0), whose derivative is infinite and turns the gradient into NaN.
    dist_sq = jnp.sum(offset * offset, axis=-1)
    return jnp.sqrt(jnp.maximum(dist_sq, 1e-20)) - self.radius

volume()

Sphere volume: 4/3 * pi * r^3.

Source code in src/brepax/primitives/sphere.py
def volume(self) -> Float[Array, ""]:
    """Sphere volume: 4/3 * pi * r^3."""
    coefficient = (4.0 / 3.0) * jnp.pi
    return coefficient * self.radius**3

Torus

Bases: Primitive

A torus defined by center, axis, major radius, and minor radius.

Attributes:

Name Type Description
center Float[Array, 3]

Center of the torus (3,).

axis Float[Array, 3]

Unit normal to the torus plane (3,). Must be normalized.

major_radius Float[Array, '']

Distance from center to tube center.

minor_radius Float[Array, '']

Tube radius.

Source code in src/brepax/primitives/torus.py
class Torus(Primitive):
    """A torus defined by center, axis, major radius, and minor radius.

    Attributes:
        center: Center of the torus (3,).
        axis: Unit normal to the torus plane (3,). Must be normalized.
        major_radius: Distance from center to tube center.
        minor_radius: Tube radius.
    """

    center: Float[Array, "3"]
    axis: Float[Array, "3"]
    major_radius: Float[Array, ""]
    minor_radius: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the torus surface.

        Args:
            x: Query points; the trailing axis holds the 3 coordinates.

        Returns:
            Signed distance per point: positive outside, negative inside.
        """
        v = x - self.center
        # Component along axis (height above torus plane)
        h = jnp.sum(v * self.axis, axis=-1)
        # Component in the torus plane (distance from center axis)
        v_plane_sq = jnp.sum(v * v, axis=-1) - h**2
        v_plane = jnp.sqrt(jnp.maximum(v_plane_sq, 1e-20))
        # Distance from the tube center ring. Clamped like v_plane above:
        # a query exactly on the ring would otherwise hit sqrt(0) and give
        # a NaN gradient.
        ring_sq = (v_plane - self.major_radius) ** 2 + h**2
        q = jnp.sqrt(jnp.maximum(ring_sq, 1e-20))
        return q - self.minor_radius

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters.

        Returns:
            A dict holding ``center``, ``axis``, ``major_radius`` and
            ``minor_radius``.
        """
        # `axis` is included so every field of the class is reported, as
        # the other axis-carrying primitives do.
        return {
            "center": self.center,
            "axis": self.axis,
            "major_radius": self.major_radius,
            "minor_radius": self.minor_radius,
        }

    def volume(self) -> Float[Array, ""]:
        """Torus volume: 2 * pi^2 * R * r^2."""
        return 2.0 * jnp.pi**2 * self.major_radius * self.minor_radius**2

parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/torus.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Returns:
        A dict holding ``center``, ``axis``, ``major_radius`` and
        ``minor_radius``.
    """
    # `axis` is included so every field read by the torus SDF is reported,
    # as the other axis-carrying primitives do.
    return {
        "center": self.center,
        "axis": self.axis,
        "major_radius": self.major_radius,
        "minor_radius": self.minor_radius,
    }

sdf(x)

Signed distance from query points to the torus surface.

Source code in src/brepax/primitives/torus.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the torus surface.

    Args:
        x: Query points; the trailing axis holds the 3 coordinates.

    Returns:
        Signed distance per point: positive outside, negative inside.
    """
    v = x - self.center
    # Component along axis (height above torus plane)
    h = jnp.sum(v * self.axis, axis=-1)
    # Component in the torus plane (distance from center axis)
    v_plane_sq = jnp.sum(v * v, axis=-1) - h**2
    v_plane = jnp.sqrt(jnp.maximum(v_plane_sq, 1e-20))
    # Distance from the tube center ring. Clamped like v_plane above: a
    # query exactly on the ring would otherwise hit sqrt(0) and give a NaN
    # gradient.
    ring_sq = (v_plane - self.major_radius) ** 2 + h**2
    q = jnp.sqrt(jnp.maximum(ring_sq, 1e-20))
    return q - self.minor_radius

volume()

Torus volume: 2 * pi^2 * R * r^2.

Source code in src/brepax/primitives/torus.py
def volume(self) -> Float[Array, ""]:
    """Torus volume: 2 * pi^2 * R * r^2."""
    coefficient = 2.0 * jnp.pi**2
    return coefficient * self.major_radius * self.minor_radius**2

box

3D axis-aligned box primitive defined by center and half-extents.

Box

Bases: Primitive

An axis-aligned box defined by center and half-extents.

Attributes:

Name Type Description
center Float[Array, 3]

Center of the box (3,).

half_extents Float[Array, 3]

Half-size in each dimension (3,).

Source code in src/brepax/primitives/box.py
class Box(Primitive):
    """Axis-aligned box described by its center and half-extents.

    Attributes:
        center: Center of the box (3,).
        half_extents: Half-size along each coordinate axis (3,).
    """

    center: Float[Array, "3"]
    half_extents: Float[Array, "3"]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the box surface.

        Negative inside, positive outside.  The exterior part is the
        norm of the per-axis overshoot; the interior part is the largest
        (least negative) per-axis offset.
        """
        offsets = jnp.abs(x - self.center) - self.half_extents
        largest = jnp.max(offsets, axis=-1)
        overshoot = jnp.maximum(offsets, 0.0)
        # The 1e-20 keeps sqrt away from zero so its gradient is finite;
        # the where() then discards that eps for points that are not outside.
        exterior_norm = jnp.sqrt(jnp.sum(overshoot**2, axis=-1) + 1e-20)
        exterior = jnp.where(largest > 0, exterior_norm, 0.0)
        interior = jnp.minimum(largest, 0.0)
        return exterior + interior

    def parameters(self) -> dict[str, Array]:
        """Return the differentiable design parameters by name."""
        center, half_extents = self.center, self.half_extents
        return {"center": center, "half_extents": half_extents}

    def volume(self) -> Float[Array, ""]:
        """Box volume: the product of the three full side lengths."""
        half_product = jnp.prod(self.half_extents)
        return 8.0 * half_product
parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/box.py
def parameters(self) -> dict[str, Array]:
    """Return the differentiable design parameters by name."""
    center, half_extents = self.center, self.half_extents
    return {"center": center, "half_extents": half_extents}
sdf(x)

Signed distance from query points to the box surface.

Source code in src/brepax/primitives/box.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the box surface.

    Negative inside, positive outside.  The exterior part is the norm
    of the per-axis overshoot; the interior part is the largest (least
    negative) per-axis offset.
    """
    offsets = jnp.abs(x - self.center) - self.half_extents
    largest = jnp.max(offsets, axis=-1)
    overshoot = jnp.maximum(offsets, 0.0)
    # The 1e-20 keeps sqrt away from zero so its gradient is finite;
    # the where() then discards that eps for points that are not outside.
    exterior_norm = jnp.sqrt(jnp.sum(overshoot**2, axis=-1) + 1e-20)
    exterior = jnp.where(largest > 0, exterior_norm, 0.0)
    interior = jnp.minimum(largest, 0.0)
    return exterior + interior
volume()

Box volume: 8 * hx * hy * hz.

Source code in src/brepax/primitives/box.py
def volume(self) -> Float[Array, ""]:
    """Box volume: the product of the three full side lengths."""
    half_product = jnp.prod(self.half_extents)
    return 8.0 * half_product

bspline_surface

B-spline surface primitive with differentiable SDF.

Wraps the NURBS evaluation and closest-point projection pipeline as a brepax.primitives.Primitive so that existing Boolean operations and metrics work with NURBS surfaces automatically.

The surface is unbounded (a single open patch, not a closed solid), so volume() returns infinity and Boolean operations fall back to grid-based evaluation, matching the pattern used by brepax.primitives.Plane and brepax.primitives.Cylinder.

BSplineSurface

Bases: Primitive

B-spline surface defined by a control point grid and knot vectors.

The SDF is computed via closest-point projection onto the surface. Control points are differentiable design variables: jax.grad flows through the unrolled Newton projection.

Attributes:

Name Type Description
control_points Float[Array, 'nu nv 3']

Control point grid, shape (n_u, n_v, 3).

knots_u Array

Knot vector in u-direction.

knots_v Array

Knot vector in v-direction.

degree_u int

Polynomial degree in u.

degree_v int

Polynomial degree in v.

Examples:

>>> import jax.numpy as jnp
>>> pts = jnp.array([[[0,0,0],[1,0,0]],
...                  [[0,1,0],[1,1,0]]], dtype=float)
>>> knots = jnp.array([0., 0., 1., 1.])
>>> surf = BSplineSurface(
...     control_points=pts, knots_u=knots, knots_v=knots,
...     degree_u=1, degree_v=1,
... )
>>> d = surf.sdf(jnp.array([0.5, 0.5, 1.0]))
Source code in src/brepax/primitives/bspline_surface.py
class BSplineSurface(Primitive):
    """B-spline surface defined by a control point grid and knot vectors.

    The SDF is computed via closest-point projection onto the surface.
    Control points are differentiable design variables: ``jax.grad``
    flows through the unrolled Newton projection.

    Attributes:
        control_points: Control point grid, shape ``(n_u, n_v, 3)``.
        knots_u: Knot vector in u-direction.
        knots_v: Knot vector in v-direction.
        degree_u: Polynomial degree in u.
        degree_v: Polynomial degree in v.
        weights: Optional array forwarded to ``evaluate_surface`` and
            ``bspline_sdf`` and reported by ``parameters()`` when set.
            Presumably per-control-point rational (NURBS) weights --
            confirm against ``evaluate_surface``.
        sign_flip: Static float forwarded to ``bspline_sdf``
            (default ``1.0``).
        param_u_range: Optional static ``(float, float)`` pair forwarded
            to ``bspline_sdf``.
        param_v_range: Optional static ``(float, float)`` pair forwarded
            to ``bspline_sdf``.
        coarse_positions: Optional array forwarded to ``bspline_sdf``.
        coarse_normals: Optional array forwarded to ``bspline_sdf``.
        trim_polygon: Optional array; not read by ``sdf`` or
            ``parameters`` in this class.
        trim_mask: Optional array; not read by ``sdf`` or
            ``parameters`` in this class.

    Examples:
        >>> import jax.numpy as jnp
        >>> pts = jnp.array([[[0,0,0],[1,0,0]],
        ...                  [[0,1,0],[1,1,0]]], dtype=float)
        >>> knots = jnp.array([0., 0., 1., 1.])
        >>> surf = BSplineSurface(
        ...     control_points=pts, knots_u=knots, knots_v=knots,
        ...     degree_u=1, degree_v=1,
        ... )
        >>> d = surf.sdf(jnp.array([0.5, 0.5, 1.0]))
    """

    control_points: Float[Array, "nu nv 3"]
    knots_u: Array = eqx.field()
    knots_v: Array = eqx.field()
    # Degrees are static fields (not array leaves).
    degree_u: int = eqx.field(static=True)
    degree_v: int = eqx.field(static=True)
    weights: Array | None = eqx.field(default=None)
    # The remaining fields are optional; sdf() forwards all but the
    # trim_* pair to bspline_sdf unchanged.
    sign_flip: float = eqx.field(default=1.0, static=True)
    param_u_range: tuple[float, float] | None = eqx.field(default=None, static=True)
    param_v_range: tuple[float, float] | None = eqx.field(default=None, static=True)
    coarse_positions: Array | None = eqx.field(default=None)
    coarse_normals: Array | None = eqx.field(default=None)
    trim_polygon: Array | None = eqx.field(default=None)
    trim_mask: Array | None = eqx.field(default=None)

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the B-spline surface.

        For batched inputs, each point is projected independently via
        Newton iteration.

        Args:
            x: Query points with a trailing dimension of 3; any leading
                batch shape is flattened for evaluation and restored on
                return.

        Returns:
            Array of shape ``x.shape[:-1]`` holding one ``bspline_sdf``
            result per query point.
        """
        shape = x.shape[:-1]
        flat = x.reshape(-1, 3)

        # Coarse grid sampling: compute once, reuse for all query points
        # NOTE(review): imported at call time rather than module level --
        # presumably to avoid a circular import; confirm.
        from brepax.nurbs.projection import _COARSE_GRID

        # Sample over [knots[degree], knots[-degree - 1]] in each direction.
        u_lo = self.knots_u[self.degree_u]
        u_hi = self.knots_u[-self.degree_u - 1]
        v_lo = self.knots_v[self.degree_v]
        v_hi = self.knots_v[-self.degree_v - 1]
        us = jnp.linspace(u_lo, u_hi, _COARSE_GRID)
        vs = jnp.linspace(v_lo, v_hi, _COARSE_GRID)
        u_grid, v_grid = jnp.meshgrid(us, vs, indexing="ij")
        # Flatten to a list of _COARSE_GRID**2 (u, v) pairs.
        u_flat_g = u_grid.ravel()
        v_flat_g = v_grid.ravel()

        def _eval_sample(u: Array, v: Array) -> Array:
            # Surface point at a single (u, v) parameter pair.
            return evaluate_surface(
                self.control_points,
                self.knots_u,
                self.knots_v,
                self.degree_u,
                self.degree_v,
                u,
                v,
                self.weights,
            )

        samples = jax.vmap(_eval_sample)(u_flat_g, v_flat_g)

        def _single_sdf(q: Array) -> Array:
            # Find closest coarse sample (stop_gradient: argmin is non-diff)
            dists = jnp.sum((samples - q) ** 2, axis=-1)
            best = jnp.argmin(dists)
            # (u0, v0) is passed to bspline_sdf as the starting parameter pair.
            u0 = jax.lax.stop_gradient(u_flat_g[best])
            v0 = jax.lax.stop_gradient(v_flat_g[best])
            return bspline_sdf(
                q,
                self.control_points,
                self.knots_u,
                self.knots_v,
                self.degree_u,
                self.degree_v,
                u0=u0,
                v0=v0,
                weights=self.weights,
                param_u_range=self.param_u_range,
                param_v_range=self.param_v_range,
                sign_flip=self.sign_flip,
                coarse_positions=self.coarse_positions,
                coarse_normals=self.coarse_normals,
            )

        result = jax.vmap(_single_sdf)(flat)
        return result.reshape(shape)

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters.

        Always contains ``control_points``, ``knots_u`` and ``knots_v``;
        ``weights`` is added only when it is not ``None``.
        """
        params: dict[str, Array] = {
            "control_points": self.control_points,
            "knots_u": self.knots_u,
            "knots_v": self.knots_v,
        }
        if self.weights is not None:
            params["weights"] = self.weights
        return params
parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/bspline_surface.py
def parameters(self) -> dict[str, Array]:
    """Return differentiable design parameters.

    Always contains ``control_points``, ``knots_u`` and ``knots_v``;
    ``weights`` is appended only when it is not ``None``.
    """
    design: dict[str, Array] = {
        "control_points": self.control_points,
        "knots_u": self.knots_u,
        "knots_v": self.knots_v,
    }
    weights = self.weights
    if weights is None:
        return design
    design["weights"] = weights
    return design
sdf(x)

Signed distance from query points to the B-spline surface.

For batched inputs, each point is projected independently via Newton iteration.

Source code in src/brepax/primitives/bspline_surface.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the B-spline surface.

    For batched inputs, each point is projected independently via
    Newton iteration.

    Args:
        x: Query points with a trailing dimension of 3; any leading
            batch shape is flattened for evaluation and restored on
            return.

    Returns:
        Array of shape ``x.shape[:-1]`` holding one ``bspline_sdf``
        result per query point.
    """
    shape = x.shape[:-1]
    flat = x.reshape(-1, 3)

    # Coarse grid sampling: compute once, reuse for all query points
    # NOTE(review): imported at call time rather than module level --
    # presumably to avoid a circular import; confirm.
    from brepax.nurbs.projection import _COARSE_GRID

    # Sample over [knots[degree], knots[-degree - 1]] in each direction.
    u_lo = self.knots_u[self.degree_u]
    u_hi = self.knots_u[-self.degree_u - 1]
    v_lo = self.knots_v[self.degree_v]
    v_hi = self.knots_v[-self.degree_v - 1]
    us = jnp.linspace(u_lo, u_hi, _COARSE_GRID)
    vs = jnp.linspace(v_lo, v_hi, _COARSE_GRID)
    u_grid, v_grid = jnp.meshgrid(us, vs, indexing="ij")
    # Flatten to a list of _COARSE_GRID**2 (u, v) pairs.
    u_flat_g = u_grid.ravel()
    v_flat_g = v_grid.ravel()

    def _eval_sample(u: Array, v: Array) -> Array:
        # Surface point at a single (u, v) parameter pair.
        return evaluate_surface(
            self.control_points,
            self.knots_u,
            self.knots_v,
            self.degree_u,
            self.degree_v,
            u,
            v,
            self.weights,
        )

    samples = jax.vmap(_eval_sample)(u_flat_g, v_flat_g)

    def _single_sdf(q: Array) -> Array:
        # Find closest coarse sample (stop_gradient: argmin is non-diff)
        dists = jnp.sum((samples - q) ** 2, axis=-1)
        best = jnp.argmin(dists)
        # (u0, v0) is passed to bspline_sdf as the starting parameter pair.
        u0 = jax.lax.stop_gradient(u_flat_g[best])
        v0 = jax.lax.stop_gradient(v_flat_g[best])
        return bspline_sdf(
            q,
            self.control_points,
            self.knots_u,
            self.knots_v,
            self.degree_u,
            self.degree_v,
            u0=u0,
            v0=v0,
            weights=self.weights,
            param_u_range=self.param_u_range,
            param_v_range=self.param_v_range,
            sign_flip=self.sign_flip,
            coarse_positions=self.coarse_positions,
            coarse_normals=self.coarse_normals,
        )

    result = jax.vmap(_single_sdf)(flat)
    return result.reshape(shape)

cone

3D infinite cone primitive defined by apex, axis, and half-angle.

Cone

Bases: Primitive

An infinite cone defined by apex position, axis direction, and half-angle.

The cone extends infinitely from the apex along the axis direction. The SDF is positive outside, negative inside.

Attributes:

Name Type Description
apex Float[Array, 3]

Apex point of the cone (3,).

axis Float[Array, 3]

Unit direction vector from apex (3,). Must be normalized.

angle Float[Array, '']

Half-angle in radians (0, pi/2).

Source code in src/brepax/primitives/cone.py
class Cone(Primitive):
    """An infinite cone defined by apex position, axis direction, and half-angle.

    The cone extends infinitely from the apex along the axis direction.
    The SDF is positive outside, negative inside.

    Attributes:
        apex: Apex point of the cone (3,).
        axis: Unit direction vector from apex (3,). Must be normalized.
        angle: Half-angle in radians, in the open interval (0, pi/2).
    """

    apex: Float[Array, "3"]
    axis: Float[Array, "3"]
    angle: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the cone surface.

        In the half-plane spanned by the axis and the radial direction
        the cone is a ray from the apex.  A query projects onto that ray
        at slant coordinate ``t = h*cos(angle) + perp*sin(angle)``; when
        ``t >= 0`` the nearest surface point lies on the ray and the
        distance is the signed offset from it, otherwise the apex is
        nearest.  Switching on ``t`` (rather than on ``h``) keeps the
        field continuous across the plane ``h = 0`` and agrees with the
        clamp used by ``foot_on_cone``.  Values for ``h >= 0`` are
        unaffected because ``t >= 0`` there.
        """
        v = x - self.apex
        # Distance along axis
        h = jnp.sum(v * self.axis, axis=-1)
        # Perpendicular distance from axis (floored so sqrt stays differentiable)
        perp = jnp.sqrt(jnp.maximum(jnp.sum(v * v, axis=-1) - h**2, 1e-20))
        sin_a = jnp.sin(self.angle)
        cos_a = jnp.cos(self.angle)
        # Position of the orthogonal projection along the cone generatrix
        t = h * cos_a + perp * sin_a
        # Signed offset from the generatrix: positive outside, negative inside
        lateral = perp * cos_a - h * sin_a
        # Distance to the apex; perp >= 1e-10 keeps this sqrt away from zero
        apex_dist = jnp.sqrt(perp**2 + h**2)
        return jnp.where(t >= 0, lateral, apex_dist)

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return {"apex": self.apex, "axis": self.axis, "angle": self.angle}
parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/cone.py
def parameters(self) -> dict[str, Array]:
    """Return the differentiable design parameters by name."""
    apex, axis, angle = self.apex, self.axis, self.angle
    return {"apex": apex, "axis": axis, "angle": angle}
sdf(x)

Signed distance from query points to the cone surface.

Source code in src/brepax/primitives/cone.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the cone surface.

    In the half-plane spanned by the axis and the radial direction the
    cone is a ray from the apex.  A query projects onto that ray at
    slant coordinate ``t = h*cos(angle) + perp*sin(angle)``; when
    ``t >= 0`` the nearest surface point lies on the ray and the
    distance is the signed offset from it, otherwise the apex is
    nearest.  Switching on ``t`` (rather than on ``h``) keeps the field
    continuous across the plane ``h = 0`` and agrees with the clamp
    used by ``foot_on_cone``.  Values for ``h >= 0`` are unaffected
    because ``t >= 0`` there.
    """
    v = x - self.apex
    # Distance along axis
    h = jnp.sum(v * self.axis, axis=-1)
    # Perpendicular distance from axis (floored so sqrt stays differentiable)
    perp = jnp.sqrt(jnp.maximum(jnp.sum(v * v, axis=-1) - h**2, 1e-20))
    sin_a = jnp.sin(self.angle)
    cos_a = jnp.cos(self.angle)
    # Position of the orthogonal projection along the cone generatrix
    t = h * cos_a + perp * sin_a
    # Signed offset from the generatrix: positive outside, negative inside
    lateral = perp * cos_a - h * sin_a
    # Distance to the apex; perp >= 1e-10 keeps this sqrt away from zero
    apex_dist = jnp.sqrt(perp**2 + h**2)
    return jnp.where(t >= 0, lateral, apex_dist)

cylinder

3D infinite cylinder primitive defined by axis, point on axis, and radius.

Cylinder

Bases: Primitive

An infinite cylinder defined by a point on the axis, axis direction, and radius.

The SDF measures perpendicular distance from the axis line minus the radius. The cylinder extends infinitely along the axis direction.

Attributes:

Name Type Description
point Float[Array, 3]

A point on the cylinder axis (3,).

axis Float[Array, 3]

Unit direction vector of the axis (3,). Must be normalized.

radius Float[Array, '']

Scalar radius (must be positive).

Source code in src/brepax/primitives/cylinder.py
class Cylinder(Primitive):
    """An infinite cylinder defined by a point on the axis, axis direction, and radius.

    The SDF measures perpendicular distance from the axis line minus the radius.
    The cylinder extends infinitely along the axis direction.

    Attributes:
        point: A point on the cylinder axis (3,).
        axis: Unit direction vector of the axis (3,). Must be normalized.
        radius: Scalar radius (must be positive).
    """

    point: Float[Array, "3"]
    axis: Float[Array, "3"]
    radius: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the cylinder surface.

        Computes the perpendicular distance from each query point to
        the cylinder axis, then subtracts the radius.

        The squared distance is floored at ``1e-20`` before the square
        root so that ``jax.grad`` stays finite for queries lying exactly
        on the axis; the value there is off by at most ``1e-10``.
        """
        # Vector from axis point to query
        v = x - self.point
        # Project onto axis to get parallel component
        proj_len = jnp.sum(v * self.axis, axis=-1, keepdims=True)
        # Perpendicular component
        perp = v - proj_len * self.axis
        # A plain norm has an infinite slope at zero, which becomes NaN
        # under differentiation; floor the squared length first.
        perp_sq = jnp.sum(perp * perp, axis=-1)
        perp_dist = jnp.sqrt(jnp.maximum(perp_sq, 1e-20))
        return perp_dist - self.radius

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return {"point": self.point, "axis": self.axis, "radius": self.radius}
parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/cylinder.py
def parameters(self) -> dict[str, Array]:
    """Return the differentiable design parameters by name."""
    point, axis, radius = self.point, self.axis, self.radius
    return {"point": point, "axis": axis, "radius": radius}
sdf(x)

Signed distance from query points to the cylinder surface.

Computes the perpendicular distance from each query point to the cylinder axis, then subtracts the radius.

Source code in src/brepax/primitives/cylinder.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the cylinder surface.

    Computes the perpendicular distance from each query point to
    the cylinder axis, then subtracts the radius.

    The squared distance is floored at ``1e-20`` before the square
    root so that ``jax.grad`` stays finite for queries lying exactly
    on the axis; the value there is off by at most ``1e-10``.
    """
    # Vector from axis point to query
    v = x - self.point
    # Project onto axis to get parallel component
    proj_len = jnp.sum(v * self.axis, axis=-1, keepdims=True)
    # Perpendicular component
    perp = v - proj_len * self.axis
    # A plain norm has an infinite slope at zero, which becomes NaN
    # under differentiation; floor the squared length first.
    perp_sq = jnp.sum(perp * perp, axis=-1)
    perp_dist = jnp.sqrt(jnp.maximum(perp_sq, 1e-20))
    return perp_dist - self.radius

disk

2D disk primitive defined by center and radius.

Disk

Bases: Primitive

A 2D disk defined by center and radius.

Attributes:

Name Type Description
center Float[Array, 2]

Center coordinates (2,).

radius Float[Array, '']

Scalar radius (must be positive).

Source code in src/brepax/primitives/disk.py
class Disk(Primitive):
    """A 2D disk defined by center and radius.

    Attributes:
        center: Center coordinates (2,).
        radius: Scalar radius (must be positive).
    """

    center: Float[Array, "2"]
    radius: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 2"]) -> Float[Array, "..."]:
        """Signed distance from query points to the disk boundary.

        The squared distance to the center is floored at ``1e-20``
        before the square root so that ``jax.grad`` stays finite for a
        query located exactly at the center; the value there is off by
        at most ``1e-10``.
        """
        offset = x - self.center
        # A plain norm has an infinite slope at zero (NaN gradient).
        dist_sq = jnp.sum(offset * offset, axis=-1)
        return jnp.sqrt(jnp.maximum(dist_sq, 1e-20)) - self.radius

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return {"center": self.center, "radius": self.radius}

    def volume(self) -> Float[Array, ""]:
        """Disk area: pi * r^2."""
        return jnp.pi * self.radius**2
parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/disk.py
def parameters(self) -> dict[str, Array]:
    """Return the differentiable design parameters by name."""
    center, radius = self.center, self.radius
    return {"center": center, "radius": radius}
sdf(x)

Signed distance from query points to the disk boundary.

Source code in src/brepax/primitives/disk.py
def sdf(self, x: Float[Array, "... 2"]) -> Float[Array, "..."]:
    """Signed distance from query points to the disk boundary.

    The squared distance to the center is floored at ``1e-20`` before
    the square root so that ``jax.grad`` stays finite for a query
    located exactly at the center; the value there is off by at most
    ``1e-10``.
    """
    offset = x - self.center
    # A plain norm has an infinite slope at zero (NaN gradient).
    dist_sq = jnp.sum(offset * offset, axis=-1)
    return jnp.sqrt(jnp.maximum(dist_sq, 1e-20)) - self.radius
volume()

Disk area: pi * r^2.

Source code in src/brepax/primitives/disk.py
def volume(self) -> Float[Array, ""]:
    """Area enclosed by the disk, ``pi * r^2``."""
    radius = self.radius
    return jnp.pi * radius**2

finite_cylinder

3D finite cylinder primitive defined by center, axis, radius, and height.

FiniteCylinder

Bases: Primitive

A finite cylinder (capped) defined by center, axis, radius, and height.

The cylinder is centered at center and extends height/2 in each direction along axis.

Attributes:

Name Type Description
center Float[Array, 3]

Center of the cylinder (3,).

axis Float[Array, 3]

Unit direction vector of the axis (3,). Must be normalized.

radius Float[Array, '']

Cylinder radius (must be positive).

height Float[Array, '']

Total height along the axis (must be positive).

Source code in src/brepax/primitives/finite_cylinder.py
class FiniteCylinder(Primitive):
    """Capped cylinder described by center, axis, radius, and height.

    The solid is centered at `center` and reaches `height/2` along
    `axis` in both directions.

    Attributes:
        center: Center of the cylinder (3,).
        axis: Unit direction vector of the axis (3,). Must be normalized.
        radius: Cylinder radius (must be positive).
        height: Total height along the axis (must be positive).
    """

    center: Float[Array, "3"]
    axis: Float[Array, "3"]
    radius: Float[Array, ""]
    height: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the finite cylinder surface.

        Treats the solid as the intersection of an infinite cylinder
        and a slab: negative inside, positive outside.
        """
        rel = x - self.center
        # Signed coordinate along the axis, measured from the center
        along = jnp.sum(rel * self.axis, axis=-1)
        # Distance from the axis line, floored so sqrt stays differentiable
        radial_sq = jnp.sum(rel * rel, axis=-1) - along**2
        radial = jnp.sqrt(jnp.maximum(radial_sq, 1e-20))

        # Signed gaps to the side wall and to the end caps
        side_gap = radial - self.radius
        cap_gap = jnp.abs(along) - self.height / 2.0

        # Exterior part: norm of the positive gaps.  The eps keeps the
        # sqrt gradient finite when both gaps are zero.
        side_out = jnp.maximum(side_gap, 0.0)
        cap_out = jnp.maximum(cap_gap, 0.0)
        corner = jnp.sqrt(side_out**2 + cap_out**2 + 1e-20)
        # Zero the exterior term (and its eps) unless some gap is positive
        exterior = jnp.where((side_out > 0) | (cap_out > 0), corner, 0.0)
        # Interior part: the larger (least negative) of the two gaps
        interior = jnp.minimum(jnp.maximum(side_gap, cap_gap), 0.0)
        return exterior + interior

    def parameters(self) -> dict[str, Array]:
        """Return the differentiable design parameters by name."""
        center, axis = self.center, self.axis
        radius, height = self.radius, self.height
        return {
            "center": center,
            "axis": axis,
            "radius": radius,
            "height": height,
        }

    def volume(self) -> Float[Array, ""]:
        """Volume of the capped cylinder, ``pi * r^2 * h``."""
        base_area = jnp.pi * self.radius**2
        return base_area * self.height
parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/finite_cylinder.py
def parameters(self) -> dict[str, Array]:
    """Return the differentiable design parameters by name."""
    center, axis = self.center, self.axis
    radius, height = self.radius, self.height
    return {
        "center": center,
        "axis": axis,
        "radius": radius,
        "height": height,
    }
sdf(x)

Signed distance from query points to the finite cylinder surface.

Source code in src/brepax/primitives/finite_cylinder.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the finite cylinder surface.

    Treats the solid as the intersection of an infinite cylinder and a
    slab: negative inside, positive outside.
    """
    rel = x - self.center
    # Signed coordinate along the axis, measured from the center
    along = jnp.sum(rel * self.axis, axis=-1)
    # Distance from the axis line, floored so sqrt stays differentiable
    radial_sq = jnp.sum(rel * rel, axis=-1) - along**2
    radial = jnp.sqrt(jnp.maximum(radial_sq, 1e-20))

    # Signed gaps to the side wall and to the end caps
    side_gap = radial - self.radius
    cap_gap = jnp.abs(along) - self.height / 2.0

    # Exterior part: norm of the positive gaps.  The eps keeps the sqrt
    # gradient finite when both gaps are zero.
    side_out = jnp.maximum(side_gap, 0.0)
    cap_out = jnp.maximum(cap_gap, 0.0)
    corner = jnp.sqrt(side_out**2 + cap_out**2 + 1e-20)
    # Zero the exterior term (and its eps) unless some gap is positive
    exterior = jnp.where((side_out > 0) | (cap_out > 0), corner, 0.0)
    # Interior part: the larger (least negative) of the two gaps
    interior = jnp.minimum(jnp.maximum(side_gap, cap_gap), 0.0)
    return exterior + interior
volume()

Finite cylinder volume: pi * r^2 * h.

Source code in src/brepax/primitives/finite_cylinder.py
def volume(self) -> Float[Array, ""]:
    """Volume of the capped cylinder, ``pi * r^2 * h``."""
    base_area = jnp.pi * self.radius**2
    return base_area * self.height

foot

Analytical foot-of-perpendicular on primitive surfaces.

Each function returns the closest point on an infinite analytical surface (plane, sphere, cylinder, cone, torus) to a query point. Closed form given the primitive parameters; fully differentiable w.r.t. both the query and the primitive parameters.

Scope is the untrimmed surface only; trim-boundary interaction is layered on via Marschner composition in a separate module. Cone and torus inner-region degeneracies (apex, tube axis) are handled with an epsilon guard on the radial direction.

Zero-denominator handling uses a safe-square-then-sqrt pattern that keeps both forward and gradient finite at the degenerate boundary. jnp.linalg.norm(v) has infinite derivative at v = 0; since jnp.where evaluates both branches in the VJP, a naive guard on norm still leaks NaN into jax.grad. The pattern used here switches on the squared norm before sqrt, so the sqrt argument is bounded below by 1 when the query is degenerate:

sq = sum(v * v)
is_ok = sq > eps_sq
safe_sq = where(is_ok, sq, 1.0)
norm = sqrt(safe_sq)   # always >= 1 at the degenerate point
direction = where(is_ok, v / norm, fallback)

foot_on_cone(query, apex, axis, angle)

Closest point on the half-cone extending forward of apex.

The cone is the surface r(h) = h * tan(angle) for h >= 0 along axis. For queries that project to h < 0 the foot is clamped to apex.

Source code in src/brepax/primitives/foot.py
def foot_on_cone(
    query: Float[Array, 3],
    apex: Float[Array, 3],
    axis: Float[Array, 3],
    angle: Float[Array, ""],
) -> Float[Array, 3]:
    """Closest point on the half-cone extending forward of ``apex``.

    The surface is ``r(h) = h * tan(angle)`` for ``h >= 0`` measured
    along ``axis``.  The query is projected onto the cone generatrix in
    its own axial half-plane; a negative slant coordinate is clamped to
    zero, which places the foot at ``apex``.
    """
    offset = query - apex
    along = jnp.dot(offset, axis)
    radial_len, radial_hat = _safe_unit(
        offset - along * axis, _axis_orthogonal(axis)
    )
    cos_a = jnp.cos(angle)
    sin_a = jnp.sin(angle)
    # Slant coordinate along the generatrix, clamped at the apex
    slant = jnp.maximum(along * cos_a + radial_len * sin_a, 0.0)
    generatrix = cos_a * axis + sin_a * radial_hat
    return apex + slant * generatrix

foot_on_cylinder(query, point, axis, radius)

Closest point on an infinite cylinder of the given radius about the axis.

On the axis itself the radial direction is ill-defined; a canonical unit vector orthogonal to axis is used to keep the computation differentiable.

Source code in src/brepax/primitives/foot.py
def foot_on_cylinder(
    query: Float[Array, 3],
    point: Float[Array, 3],
    axis: Float[Array, 3],
    radius: Float[Array, ""],
) -> Float[Array, 3]:
    """Closest point on an infinite cylinder of ``radius`` about the axis.

    The radial direction is undefined for a query on the axis itself;
    ``_axis_orthogonal(axis)`` is supplied to ``_safe_unit`` as the
    fallback direction for that case.
    """
    offset = query - point
    along = jnp.dot(offset, axis)
    _, outward = _safe_unit(offset - along * axis, _axis_orthogonal(axis))
    # Foot of the query on the axis line, then step out by the radius
    axis_foot = point + along * axis
    return axis_foot + radius * outward

foot_on_plane(query, normal, offset)

Closest point on an infinite plane normal . x = offset.

Source code in src/brepax/primitives/foot.py
def foot_on_plane(
    query: Float[Array, 3],
    normal: Float[Array, 3],
    offset: Float[Array, ""],
) -> Float[Array, 3]:
    """Closest point on the infinite plane ``normal . x = offset``.

    Moves the query back along ``normal`` by its signed plane distance.
    """
    height_above = jnp.dot(query, normal) - offset
    correction = height_above * normal
    return query - correction

foot_on_sphere(query, center, radius)

Closest point on a sphere of the given radius around center.

At the center (degenerate) the result is center + radius * e_z; an arbitrary direction is required to keep the gradient finite.

Source code in src/brepax/primitives/foot.py
def foot_on_sphere(
    query: Float[Array, 3],
    center: Float[Array, 3],
    radius: Float[Array, ""],
) -> Float[Array, 3]:
    """Closest point on a sphere of ``radius`` around ``center``.

    ``e_z`` is passed to ``_safe_unit`` as the fallback direction for
    the degenerate query at the center, where every direction is
    equally close.
    """
    fallback = jnp.array([0.0, 0.0, 1.0])
    _, outward = _safe_unit(query - center, fallback)
    return center + radius * outward

foot_on_torus(query, center, axis, major_radius, minor_radius)

Closest point on the torus tube around the major ring.

Project to the tube-center ring first, then shift outward by minor_radius. On the central axis (radial == 0) a canonical in-plane direction is used; on the tube center itself (dq == 0) the fallback is the radial direction.

Source code in src/brepax/primitives/foot.py
def foot_on_torus(
    query: Float[Array, 3],
    center: Float[Array, 3],
    axis: Float[Array, 3],
    major_radius: Float[Array, ""],
    minor_radius: Float[Array, ""],
) -> Float[Array, 3]:
    """Closest point on the torus tube around the major ring.

    Two projections are chained: first onto the tube-center ring, then
    from that ring point outward by ``minor_radius`` toward the query.
    ``_axis_orthogonal(axis)`` is the fallback direction for the first
    step (query on the central axis) and the radial direction is the
    fallback for the second (query on the ring itself).
    """
    offset = query - center
    height = jnp.dot(offset, axis)
    in_plane = offset - height * axis
    _, ring_dir = _safe_unit(in_plane, _axis_orthogonal(axis))
    ring_point = center + major_radius * ring_dir
    _, tube_dir = _safe_unit(query - ring_point, ring_dir)
    return ring_point + minor_radius * tube_dir

plane

3D plane primitive defined by normal vector and offset.

Plane

Bases: Primitive

An infinite plane (half-space boundary) defined by normal and offset.

The SDF is positive on the side the normal points toward (outside), and negative on the opposite side (inside).

Attributes:

Name Type Description
normal Float[Array, 3]

Unit normal vector (3,). Must be normalized.

offset Float[Array, '']

Signed distance from origin to the plane along the normal.

Source code in src/brepax/primitives/plane.py
class Plane(Primitive):
    """Infinite plane (boundary of a half-space) given by normal and offset.

    The SDF is positive on the side the normal points toward (outside)
    and negative on the other side (inside).

    Attributes:
        normal: Unit normal vector (3,). Must be normalized.
        offset: Signed distance from origin to the plane along the normal.
    """

    normal: Float[Array, "3"]
    offset: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the plane.

        Projects each query onto ``normal`` and subtracts ``offset``.
        """
        along_normal = jnp.sum(x * self.normal, axis=-1)
        return along_normal - self.offset

    def parameters(self) -> dict[str, Array]:
        """Return the differentiable design parameters by name."""
        normal, offset = self.normal, self.offset
        return {"normal": normal, "offset": offset}
parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/plane.py
def parameters(self) -> dict[str, Array]:
    """Return the differentiable design parameters by name."""
    normal, offset = self.normal, self.offset
    return {"normal": normal, "offset": offset}
sdf(x)

Signed distance from query points to the plane.

Source code in src/brepax/primitives/plane.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the plane.

    Projects each query onto ``normal`` and subtracts ``offset``.
    """
    along_normal = jnp.sum(x * self.normal, axis=-1)
    return along_normal - self.offset

sphere

3D sphere primitive defined by center and radius.

Sphere

Bases: Primitive

A 3D sphere defined by center and radius.

Attributes:

Name Type Description
center Float[Array, 3]

Center coordinates (3,).

radius Float[Array, '']

Scalar radius (must be positive).

Source code in src/brepax/primitives/sphere.py
class Sphere(Primitive):
    """A 3D sphere defined by center and radius.

    Attributes:
        center: Center coordinates (3,).
        radius: Scalar radius (must be positive).
    """

    center: Float[Array, "3"]
    radius: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the sphere surface.

        The squared distance to the center is floored at ``1e-20``
        before the square root so that ``jax.grad`` stays finite for a
        query located exactly at the center; the value there is off by
        at most ``1e-10``.
        """
        offset = x - self.center
        # A plain norm has an infinite slope at zero (NaN gradient).
        dist_sq = jnp.sum(offset * offset, axis=-1)
        return jnp.sqrt(jnp.maximum(dist_sq, 1e-20)) - self.radius

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters."""
        return {"center": self.center, "radius": self.radius}

    def volume(self) -> Float[Array, ""]:
        """Sphere volume: 4/3 * pi * r^3."""
        return (4.0 / 3.0) * jnp.pi * self.radius**3
parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/sphere.py
def parameters(self) -> dict[str, Array]:
    """Return the differentiable design parameters by name."""
    center, radius = self.center, self.radius
    return {"center": center, "radius": radius}
sdf(x)

Signed distance from query points to the sphere surface.

Source code in src/brepax/primitives/sphere.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the sphere surface.

    The squared distance to the center is floored at ``1e-20`` before
    the square root so that ``jax.grad`` stays finite for a query
    located exactly at the center; the value there is off by at most
    ``1e-10``.
    """
    offset = x - self.center
    # A plain norm has an infinite slope at zero (NaN gradient).
    dist_sq = jnp.sum(offset * offset, axis=-1)
    return jnp.sqrt(jnp.maximum(dist_sq, 1e-20)) - self.radius
volume()

Sphere volume: 4/3 * pi * r^3.

Source code in src/brepax/primitives/sphere.py
def volume(self) -> Float[Array, ""]:
    """Closed-form sphere volume, ``(4/3) * pi * radius**3``."""
    sphere_coeff = (4.0 / 3.0) * jnp.pi
    return sphere_coeff * self.radius**3

torus

3D torus primitive defined by center, axis, major radius, and minor radius.

Torus

Bases: Primitive

A torus defined by center, axis, major radius, and minor radius.

Attributes:

Name Type Description
center Float[Array, 3]

Center of the torus (3,).

axis Float[Array, 3]

Unit normal to the torus plane (3,). Must be normalized.

major_radius Float[Array, '']

Distance from center to tube center.

minor_radius Float[Array, '']

Tube radius.

Source code in src/brepax/primitives/torus.py
class Torus(Primitive):
    """A torus defined by center, axis, major radius, and minor radius.

    Attributes:
        center: Center of the torus (3,).
        axis: Unit normal to the torus plane (3,). Must be normalized.
        major_radius: Distance from center to tube center.
        minor_radius: Tube radius.
    """

    center: Float[Array, "3"]
    axis: Float[Array, "3"]
    major_radius: Float[Array, ""]
    minor_radius: Float[Array, ""]

    def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
        """Signed distance from query points to the torus surface.

        Both square roots are evaluated on values clamped to a tiny
        positive floor so that ``jax.grad`` stays finite for points on the
        torus axis and for points lying exactly on the tube-center ring
        (the derivative of ``sqrt`` at 0 is infinite and would otherwise
        yield NaN gradients).
        """
        v = x - self.center
        # Component along axis (height above torus plane)
        h = jnp.sum(v * self.axis, axis=-1)
        # Component in the torus plane (distance from center axis)
        v_plane_sq = jnp.sum(v * v, axis=-1) - h**2
        v_plane = jnp.sqrt(jnp.maximum(v_plane_sq, 1e-20))
        # Distance from the tube center ring, guarded like v_plane above
        q_sq = (v_plane - self.major_radius) ** 2 + h**2
        q = jnp.sqrt(jnp.maximum(q_sq, 1e-20))
        return q - self.minor_radius

    def parameters(self) -> dict[str, Array]:
        """Return differentiable design parameters, keyed by attribute name."""
        # NOTE(review): ``axis`` is not exposed here -- presumably because it
        # must stay unit-length; confirm this omission is intentional.
        return {
            "center": self.center,
            "major_radius": self.major_radius,
            "minor_radius": self.minor_radius,
        }

    def volume(self) -> Float[Array, ""]:
        """Torus volume: 2 * pi^2 * R * r^2."""
        return 2.0 * jnp.pi**2 * self.major_radius * self.minor_radius**2
parameters()

Return differentiable design parameters.

Source code in src/brepax/primitives/torus.py
def parameters(self) -> dict[str, Array]:
    """Expose the torus's design variables as a name -> value mapping.

    The mapping holds ``"center"``, ``"major_radius"`` and
    ``"minor_radius"``, in that order.
    """
    exposed = ("center", "major_radius", "minor_radius")
    return {field: getattr(self, field) for field in exposed}
sdf(x)

Signed distance from query points to the torus surface.

Source code in src/brepax/primitives/torus.py
def sdf(self, x: Float[Array, "... 3"]) -> Float[Array, "..."]:
    """Signed distance from query points to the torus surface.

    Both square roots are evaluated on values clamped to a tiny positive
    floor so that ``jax.grad`` stays finite for points on the torus axis
    and for points lying exactly on the tube-center ring (the derivative
    of ``sqrt`` at 0 is infinite and would otherwise yield NaN gradients).
    """
    v = x - self.center
    # Component along axis (height above torus plane)
    h = jnp.sum(v * self.axis, axis=-1)
    # Component in the torus plane (distance from center axis)
    v_plane_sq = jnp.sum(v * v, axis=-1) - h**2
    v_plane = jnp.sqrt(jnp.maximum(v_plane_sq, 1e-20))
    # Distance from the tube center ring, guarded like v_plane above
    q_sq = (v_plane - self.major_radius) ** 2 + h**2
    q = jnp.sqrt(jnp.maximum(q_sq, 1e-20))
    return q - self.minor_radius
volume()

Torus volume: 2 * pi^2 * R * r^2.

Source code in src/brepax/primitives/torus.py
def volume(self) -> Float[Array, ""]:
    """Closed-form torus volume, ``2 * pi**2 * major_radius * minor_radius**2``."""
    ring_factor = 2.0 * jnp.pi**2 * self.major_radius
    return ring_factor * self.minor_radius**2