
Commit e0f304d

Generic Tensor
* Support type declarations Tensor[T]
* Add type hints to wrap(), tensor(), layout(), linspace(), arange(), *range(), unstack()
1 parent e48df33 commit e0f304d
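
For illustration, a minimal sketch of what the new generic annotations express. This snippet is not part of the commit; it assumes the public phiml.math API and the Generic[T] base introduced below:

# With Tensor generic over its element type, user code can declare what a
# tensor holds; comparison operators are now hinted to return Tensor[bool].
from phiml import math
from phiml.math import Tensor, channel

v: Tensor[float] = math.wrap([0.0, 0.5, 1.0], channel(vector='x,y,z'))
mask: Tensor[bool] = v > 0.25  # __gt__ is annotated as returning Tensor[bool]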

File tree: 3 files changed, +47 −45 lines

* phiml/math/_magic_ops.py (+4 −4)
* phiml/math/_ops.py (+8 −8)
* phiml/math/_tensors.py (+35 −33)

phiml/math/_magic_ops.py

Lines changed: 4 additions & 4 deletions
@@ -15,6 +15,9 @@
 from ..backend._dtype import DType


+MagicType = TypeVar('MagicType')
+OtherMagicType = TypeVar('OtherMagicType')
+
 PhiTreeNodeType = TypeVar('PhiTreeNodeType')  # Defined in phiml.math.magic: tuple, list, dict, custom


@@ -71,7 +74,7 @@ def slice_(value: PhiTreeNodeType, slices: Union[Dict[str, Union[int, slice, str
     raise ValueError(f"value must be a PhiTreeNode but got {type(value)}")


-def unstack(value, dim: DimFilter, expand=False) -> tuple:
+def unstack(value: MagicType, dim: DimFilter, expand=False) -> Tuple[MagicType, ...]:
     """
     Un-stacks a `Sliceable` along one or multiple dimensions.

@@ -989,9 +992,6 @@ def replace(obj: PhiTreeNodeType, **updates) -> PhiTreeNodeType:

 # Other Ops

-MagicType = TypeVar('MagicType')
-OtherMagicType = TypeVar('OtherMagicType')
-

 def cast(x: MagicType, dtype: Union[DType, type]) -> OtherMagicType:
     """

phiml/math/_ops.py

Lines changed: 8 additions & 8 deletions
@@ -731,7 +731,7 @@ def meshgrid(dims: Union[Callable, Shape] = spatial, stack_dim: Union[Shape, str
     return stack_tensors(channels, stack_dim)


-def linspace(start: Union[float, Tensor, tuple, list], stop: Union[float, Tensor, tuple, list], dim: Shape) -> Tensor:
+def linspace(start: Union[float, Tensor, tuple, list], stop: Union[float, Tensor, tuple, list], dim: Shape) -> Tensor[float]:
     """
     Returns `number` evenly spaced numbers between `start` and `stop` along `dim`.

@@ -772,7 +772,7 @@ def linspace(start: Union[float, Tensor, tuple, list], stop: Union[float, Tensor
     return map_(linspace, start, stop, dim=dim)


-def arange(dim: Shape, start_or_stop: Union[int, None] = None, stop: Union[int, None] = None, step=1, backend=None):
+def arange(dim: Shape, start_or_stop: Union[int, None] = None, stop: Union[int, None] = None, step=1, backend=None) -> Tensor[int]:
     """
     Returns evenly spaced values between `start` and `stop`.
     If only one limit is given, `0` is used for the start.
@@ -819,7 +819,7 @@ def batched_range(dims: Shape, start: Tensor, stop: Tensor, step: Tensor):
     return batched_range(dim, start, stop, step)


-def range_tensor(*shape: Shape):
+def range_tensor(*shape: Shape) -> Tensor[int]:
     """
     Returns a `Tensor` with given `shape` containing the linear indices of each element.
     For 1D tensors, this equivalent to `arange()` with `step=1`.
@@ -838,31 +838,31 @@ def range_tensor(*shape: Shape):
     return unpack_dim(data, 'range', shape)


-def brange(start: int = 0, **stop: int):
+def brange(start: int = 0, **stop: int) -> Tensor[int]:
     """ Construct a range `Tensor` along one batch dim. """
     assert len(stop) == 1, f"brange() requires exactly one stop dimension but got {stop}"
     return arange(batch(next(iter(stop))), start, next(iter(stop.values())))


-def drange(start: int = 0, **stop: int):
+def drange(start: int = 0, **stop: int) -> Tensor[int]:
     """ Construct a range `Tensor` along one dual dim. """
     assert len(stop) == 1, f"drange() requires exactly one stop dimension but got {stop}"
     return arange(dual(next(iter(stop))), start, next(iter(stop.values())))


-def irange(start: int = 0, **stop: int):
+def irange(start: int = 0, **stop: int) -> Tensor[int]:
     """ Construct a range `Tensor` along one instance dim. """
     assert len(stop) == 1, f"irange() requires exactly one stop dimension but got {stop}"
     return arange(instance(next(iter(stop))), start, next(iter(stop.values())))


-def srange(start: int = 0, **stop: int):
+def srange(start: int = 0, **stop: int) -> Tensor[int]:
     """ Construct a range `Tensor` along one spatial dim. """
     assert len(stop) == 1, f"srange() requires exactly one stop dimension but got {stop}"
     return arange(spatial(next(iter(stop))), start, next(iter(stop.values())))


-def crange(start: int = 0, **stop: int):
+def crange(start: int = 0, **stop: int) -> Tensor[int]:
     """ Construct a range `Tensor` along one channel dim. """
     assert len(stop) == 1, f"crange() requires exactly one stop dimension but got {stop}"
     return arange(channel(next(iter(stop))), start, next(iter(stop.values())))
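
A hedged usage sketch of the functions whose return hints changed, with values following the docstrings above (assumes these functions are re-exported under phiml.math):

from phiml import math
from phiml.math import spatial

xs = math.linspace(0.0, 1.0, spatial(x=11))  # hinted Tensor[float], 11 values
idx = math.arange(spatial(x=5), 0, 5)        # hinted Tensor[int], values 0..4
pts = math.irange(points=10)                 # Tensor[int] along instance dim 'points'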

phiml/math/_tensors.py

Lines changed: 35 additions & 33 deletions
@@ -5,7 +5,7 @@
 import traceback
 import warnings
 from contextlib import contextmanager
-from typing import Union, TypeVar, Sequence, Any, Literal
+from typing import Union, TypeVar, Sequence, Any, Literal, Generic

 from dataclasses import dataclass
 from typing import Tuple, Callable, List
@@ -28,7 +28,9 @@
 from ..backend.xops import ExtraOperator


-class Tensor:
+T = TypeVar('T')  # data type of tensors
+
+class Tensor(Generic[T]):
     """
     Abstract base class to represent structured data of one data type.
     This class replaces the native tensor classes `numpy.ndarray`, `torch.Tensor`, `tensorflow.Tensor` or `jax.numpy.ndarray` as the main data container in Φ-ML.
@@ -470,13 +472,13 @@ def _getitem(self, selection: dict) -> 'Tensor':
     def __setitem__(self, key, value):
         raise SyntaxError("Tensors are not editable to preserve the autodiff chain. This feature might be added in the future. To update part of a tensor, use math.where() or math.scatter()")

-    def __unstack__(self, dims: Tuple[str, ...]) -> Tuple['Tensor', ...]:  # from phiml.math.magic.Sliceable
+    def __unstack__(self, dims: Tuple[str, ...]) -> Tuple['Tensor[T]', ...]:  # from phiml.math.magic.Sliceable
         if len(dims) == 1:
             return self._unstack(dims[0])
         else:
             return NotImplemented

-    def _unstack(self, dim: str):
+    def _unstack(self, dim: str) -> Tuple['Tensor[T]', ...]:
         """
         Splits this tensor along the specified dimension.
         The returned tensors have the same dimensions as this tensor save the unstacked dimension.
@@ -568,12 +570,12 @@ def dimension(self, name: Union[str, Shape]) -> 'TensorDim':
         else:
             raise ValueError(name)

-    def pack(self, dims, packed_dim):
+    def pack(self, dims, packed_dim) -> 'Tensor[T]':
         """ See `pack_dims()` """
         from ._ops import pack_dims
         return pack_dims(self, dims, packed_dim)

-    def unpack(self, dim, unpacked_dims):
+    def unpack(self, dim, unpacked_dims) -> 'Tensor[T]':
         """ See `unpack_dim()` """
         from ._ops import unpack_dim
         return unpack_dim(self, dim, unpacked_dims)
@@ -584,15 +586,15 @@ def T(self):
         # return self._with_shape_replaced(self.shape.transposed())

     @property
-    def Ti(self):
+    def Ti(self) -> 'Tensor[T]':
         return self._with_shape_replaced(self.shape.transpose(INSTANCE_DIM))

     @property
-    def Tc(self):
+    def Tc(self) -> 'Tensor[T]':
         return self._with_shape_replaced(self.shape.transpose(CHANNEL_DIM))

     @property
-    def Ts(self):
+    def Ts(self) -> 'Tensor[T]':
         return self._with_shape_replaced(self.shape.transpose(SPATIAL_DIM))

     def map(self, function: Callable, dims=shape_, range=range, unwrap_scalars=True, **kwargs):
@@ -676,7 +678,7 @@ def __mod__(self, other):
     def __rmod__(self, other):
         return self._op2(other, operator.mod, True)

-    def __eq__(self, other) -> 'Tensor':
+    def __eq__(self, other) -> 'Tensor[bool]':
         if self is other:
             return expand(True, self.shape)
         if _EQUALITY_REDUCE[-1]['type'] == 'ref':
@@ -693,7 +695,7 @@ def __eq__(self, other) -> 'Tensor':
         else:
             return wrap(False)

-    def __ne__(self, other) -> 'Tensor':
+    def __ne__(self, other) -> 'Tensor[bool]':
         if _EQUALITY_REDUCE[-1]['type'] == 'ref':
             return wrap(self is not other)
         elif _EQUALITY_REDUCE[-1]['type'] == 'shape_and_value':
@@ -708,49 +710,49 @@ def __ne__(self, other) -> 'Tensor':
         else:
             return wrap(True)

-    def __lt__(self, other):
+    def __lt__(self, other) -> 'Tensor[bool]':
         return self._op2(other, operator.gt, True)

-    def __le__(self, other):
+    def __le__(self, other) -> 'Tensor[bool]':
         return self._op2(other, operator.ge, True)

-    def __gt__(self, other):
+    def __gt__(self, other) -> 'Tensor[bool]':
         return self._op2(other, operator.gt, False)

-    def __ge__(self, other):
+    def __ge__(self, other) -> 'Tensor[bool]':
         return self._op2(other, operator.ge, False)

-    def __lshift__(self, other):
+    def __lshift__(self, other) -> 'Tensor[T]':
         return self._op2(other, operator.lshift, False)

-    def __rlshift__(self, other):
+    def __rlshift__(self, other) -> 'Tensor[T]':
         return self._op2(other, operator.lshift, True)

-    def __rshift__(self, other):
+    def __rshift__(self, other) -> 'Tensor[T]':
         return self._op2(other, operator.rshift, False)

-    def __rrshift__(self, other):
+    def __rrshift__(self, other) -> 'Tensor[T]':
         return self._op2(other, operator.rshift, True)

-    def __abs__(self):
+    def __abs__(self) -> 'Tensor[T]':
         return self._op1(lambda t: choose_backend(t).abs(t))

-    def __round__(self, n=None):
+    def __round__(self, n=None) -> 'Tensor[int]':
         return self._op1(lambda t: choose_backend(t).round(t))

-    def __copy__(self):
+    def __copy__(self) -> 'Tensor[T]':
         return self._op1(lambda t: choose_backend(t).copy(t, only_mutable=True))

-    def __deepcopy__(self, memodict={}):
+    def __deepcopy__(self, memodict={}) -> 'Tensor[T]':
         return self._op1(lambda t: choose_backend(t).copy(t, only_mutable=False))

-    def __neg__(self) -> 'Tensor':
+    def __neg__(self) -> 'Tensor[T]':
         return self._op1(operator.neg)

-    def __invert__(self) -> 'Tensor':
+    def __invert__(self) -> 'Tensor[T]':
         return self._op1(lambda t: choose_backend(t).invert(t))

-    def __reversed__(self):
+    def __reversed__(self) -> 'Tensor[T]':
         assert self.shape.channel.rank == 1
         return self[::-1]

@@ -763,11 +765,11 @@ def __iter__(self):
         native = self.native([self.shape])
         return iter(native)

-    def item(self):
+    def item(self) -> T:
         assert self.shape.volume == 1, f"Tensor.item() is only available for single-element Tensors but got {self.shape}"
         return next(iter(self))

-    def __matmul__(self, other):
+    def __matmul__(self, other) -> 'Tensor[bool]':
         from ._ops import dot
         assert isinstance(other, Tensor), f"Matmul '@' requires two Tensor arguments but got {type(other)}"
         if not self.shape.dual_rank and self.shape.channel_rank:
@@ -953,7 +955,7 @@ def __init__(self, obj, stack_dim: Shape):
             warnings.warn(f"Empty stack_dim for Layout with value {obj}")

     @staticmethod
-    def _recursive_get_shapes(obj, s: Shape) -> Tuple[Shape]:
+    def _recursive_get_shapes(obj, s: Shape) -> Tuple[Shape, ...]:
         if not s:
             return shape(obj, allow_unshaped=True),
         elif isinstance(obj, (tuple, list)):
@@ -1670,10 +1672,10 @@ def _simplify(self):
         return self


-def tensor(data,
+def tensor(data: Union[Sequence[T], T],
            *shape: Union[Shape, str, list],
            convert: bool = True,
-           default_list_dim=channel('vector')) -> Tensor:  # TODO assume convert_unsupported, add convert_external=False for constants
+           default_list_dim=channel('vector')) -> Tensor[T]:  # TODO assume convert_unsupported, add convert_external=False for constants
     """
     Create a Tensor from the specified `data`.
     If `convert=True`, converts `data` to the preferred format of the default backend.
@@ -1814,12 +1816,12 @@ def tensor(data,
     raise ValueError(f"{type(data)} is not supported. Only (Tensor, tuple, list, np.ndarray, native tensors) are allowed.\nCurrent backends: {BACKENDS}")


-def wrap(data, *shape: Union[Shape, str, list], default_list_dim=channel('vector')) -> Tensor:
+def wrap(data: Union[Sequence[T], T], *shape: Union[Shape, str, list], default_list_dim=channel('vector')) -> Tensor[T]:
     """ Short for `phiml.math.tensor()` with `convert=False`. """
     return tensor(data, *shape, convert=False, default_list_dim=default_list_dim)


-def layout(objects, *shape: Union[Shape, str]) -> Tensor:
+def layout(objects: Union[Sequence[T], T], *shape: Union[Shape, str]) -> Tensor[T]:
     """
     Wraps a Python tree in a `Tensor`, allowing elements to be accessed via dimensions.
     A python tree is a structure of nested `tuple`, `list`, `dict` and *leaf* objects where leaves can be any Python object.
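
The pattern applied throughout _tensors.py is standard typing.Generic with string forward references. A self-contained illustration of the same idiom, using a hypothetical Box class rather than phiml code:

from typing import Generic, Tuple, TypeVar

T = TypeVar('T')  # element type, as in Tensor(Generic[T])

class Box(Generic[T]):
    """Toy container mirroring the Tensor[T] annotations above."""
    def __init__(self, *items: T):
        self._items = tuple(items)

    def _unstack(self) -> Tuple['Box[T]', ...]:
        # Like Tensor._unstack: each part preserves the element type T.
        return tuple(Box(i) for i in self._items)

    def item(self) -> T:
        assert len(self._items) == 1
        return self._items[0]

b: Box[int] = Box(1, 2, 3)
parts = b._unstack()  # a checker infers Tuple[Box[int], ...]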

0 commit comments
