from __future__ import annotations
import dataclasses
import typing
from metadsl import *
from metadsl_rewrite import *
import metadsl.typing_tools
from .strategies import *
__all__ = ["Abstraction", "Variable"]
T = typing.TypeVar("T")
U = typing.TypeVar("U")
V = typing.TypeVar("V")
[docs]@dataclasses.dataclass(eq=False)
class Variable:
"""
should only be equal to itself
"""
def __str__(self):
return hex(id(self))
def __repr__(self):
return f"Variable: {str(self)}"
[docs]class Abstraction(Expression, typing.Generic[T, U]):
@expression
def __call__(self, arg: T) -> U:
...
[docs] @expression
@classmethod
def create(cls, var: T, body: U) -> Abstraction[T, U]:
...
[docs] @expression
@classmethod
def from_fn(cls, fn: typing.Callable[[T], U]) -> Abstraction[T, U]:
v: T = cls.create_variable(Variable())
return cls.create(v, fn(v))
[docs] @expression
@classmethod
def create_variable(cls, variable: Variable) -> T:
...
@expression
def __add__(self, other: Abstraction[V, T]) -> Abstraction[V, U]:
"""
Composes this function with another.
(f + g)(x) == f(g(x))
"""
...
[docs] @staticmethod
@expression
def fix(fn: Abstraction[T, T]) -> T:
"""
Fixed pointer operator, used to define recursive functions.
"""
return fn(Abstraction.fix(fn))
[docs] @expression # type: ignore
@property
def unfix(self) -> Abstraction[Abstraction[T, U], Abstraction[T, U]]:
"""
Extracts the inner function created from a fixed point operator.
If its not a fixed point funciton, returns a function that always returns the original.
"""
...
# Run the `from_fn` before any other rules, so that variables
# are all created before any are replaced
# This is needed if a variable is caught in an inner scope of a local function
from_fn_rule = register.pre(default_rule(Abstraction[T, U].from_fn))
# Run the fixed point operator rule after all others, so that
# it is only expanded if we need it.
fix_rule = register.post(default_rule(Abstraction.fix))
def _replace(body: U, var: T, arg: T) -> U:
"""
Replaces all instances of `var` with `arg` inside of `body`,
except for local bindings of `var` as declared in other `from_fn`s inside.
"""
if body == var:
return arg # type: ignore
if not isinstance(body, Expression):
return body
is_abstraction = (
isinstance(body.function, metadsl.typing_tools.BoundInfer)
and body.function.fn == Abstraction.create.fn # type: ignore
)
# If is a `from_fn` node with the same var bound, don't try replacing its children
if is_abstraction and body.args[0] == var:
return body # type: ignore
return body._map(lambda e: _replace(e, var, arg)) # type: ignore
@register_core
@rule
def compose(vl: T, bl: U, vr: V, br: T) -> R[Abstraction[V, U]]:
# We want to define composition for the lambda calculus
# We start with our two functions, each with a variable and a body:
# f = 𝜆vl.bl
# g = 𝜆vr.br
# We want to compute their composition:
# 𝜆x.f(g(x))
# We can start by replacing function application with replacing all instances of the variable in the body
# == f(br[vr/x])
# == bl[vl/br[vr/x]]
# Now, what we want is to pull out the `x` replacement to the outside, so we can create another function from this
# Assuming no overlapping variables names in the scope, we should be able to do the replacements in sequence:
# == bl[vl/br][vr/x]
# == 𝜆vr.bl[vl/br]
# this checks out typing wise, which is nice!
return (
Abstraction[T, U].create(vl, bl) # type: ignore
+ Abstraction[V, T].create(vr, br),
lambda: Abstraction.create(vr, _replace(bl, vl, br)),
)
@register_core # type: ignore
@rule
def beta_reduce(var: T, body: U, arg: T) -> R[U]:
return (Abstraction[T, U].create(var, body)(arg), lambda: _replace(body, var, arg))
@register_core
@rule
def unfix_normal(
var: T, body: U
) -> R[Abstraction[Abstraction[T, U], Abstraction[T, U]]]:
original = Abstraction.create(var, body)
# If this is a normal abstraction, then just return the original
return (
original.unfix,
lambda: Abstraction[Abstraction[T, U], Abstraction[T, U]].from_fn(
lambda _: original
),
)
@register_core
@rule
def unfix_fixed(
a: Abstraction[Abstraction[T, U], Abstraction[T, U]]
) -> R[Abstraction[Abstraction[T, U], Abstraction[T, U]]]:
return (
Abstraction.fix(a).unfix,
a,
)