summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/__init__.py90
1 files changed, 53 insertions, 37 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 23fe913..64f3d63 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -7,10 +7,11 @@ import numpy as np
import scipy.sparse
import sympy
-from polymatrix.expression.expression import Expression
+from polymatrix.expression.mixins.expressionmixin import ExpressionMixin as Expression
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
-from polymatrix.expression.mixins.expressionmixin import ExpressionMixin
+# from polymatrix.expression.mixins.expressionmixin import ExpressionMixin
from polymatrix.expression.mixins.parametrizeexprmixin import ParametrizeExprMixin
+from polymatrix.expression.mixins.parametrizematrixexprmixin import ParametrizeMatrixExprMixin
from polymatrix.expressionstate.expressionstate import ExpressionState
from polymatrix.expression.init.initblockdiagexpr import init_block_diag_expr
from polymatrix.expression.init.initexpression import init_expression
@@ -19,7 +20,7 @@ from polymatrix.expression.init.initfromsympyexpr import init_from_sympy_expr
from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
from polymatrix.expression.init.initvstackexpr import init_v_stack_expr
from polymatrix.polymatrix.polymatrix import PolyMatrix
-from polymatrix.expression.utils.getvariableindices import get_variable_indices, get_variable_indices_from_variable
+from polymatrix.expression.utils.getvariableindices import get_variable_indices_from_variable
from polymatrix.statemonad.init.initstatemonad import init_state_monad
from polymatrix.statemonad.mixins.statemonadmixin import StateMonadMixin
from polymatrix.expression.utils.monomialtoindex import monomial_to_index
@@ -52,8 +53,11 @@ def from_polymatrix(
)
def from_(
- data: tuple[tuple[float]],
+ data: str | tuple[tuple[float]],
):
+ if isinstance(data, str):
+ return from_(1).parametrize(data)
+
return from_sympy(data)
def v_stack(
@@ -308,7 +312,7 @@ def shape(
@dataclasses.dataclass
class MatrixBuffer:
- data: dict[int, ...]
+ data: dict[int, np.ndarray]
n_row: int
n_param: int
@@ -394,6 +398,9 @@ class MatrixRepresentations:
if isinstance(x, tuple) or isinstance(x, list):
x = np.array(x).reshape(-1, 1)
+ elif x.shape[0] == 1:
+ x = x.reshape(-1, 1)
+
def acc_x_powers(acc, _):
next = (acc @ x.T).reshape(-1, 1)
return next
@@ -433,14 +440,14 @@ class MatrixRepresentations:
def to_matrix_repr(
expressions: Expression | tuple[Expression],
- variables: Expression,
+ variables: Expression = None,
sorted: bool = None,
) -> StateMonadMixin[ExpressionState, tuple[tuple[tuple[np.ndarray, ...], ...], tuple[int, ...]]]:
if isinstance(expressions, Expression):
expressions = (expressions,)
- assert isinstance(variables, ExpressionBaseMixin), f'{variables=}'
+ assert isinstance(variables, ExpressionBaseMixin) or variables is None, f'{variables=}'
def func(state: ExpressionState):
@@ -453,32 +460,36 @@ def to_matrix_repr(
return state, underlying_list + (underlying,)
- *_, (state, underlying_list) = tuple(itertools.accumulate(
+ *_, (state, polymatrix_list) = tuple(itertools.accumulate(
expressions,
acc_underlying_application,
initial=(state, tuple()),
))
- state, variable_index = get_variable_indices_from_variable(state, variables)
-
- if sorted:
- tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
-
- sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
+ if variables is None:
+ sorted_variable_index = tuple()
else:
- sorted_variable_index = variable_index
+ state, variable_index = get_variable_indices_from_variable(state, variables)
- assert len(sorted_variable_index) == len(set(sorted_variable_index)), f'{sorted_variable_index=} contains repeated variables'
+ if sorted:
+ tagged_variable_index = tuple((offset, state.get_name_from_offset(offset)) for offset in variable_index)
+
+ sorted_variable_index = tuple(v[0] for v in sorted(tagged_variable_index, key=lambda v: v[1]))
+
+ else:
+ sorted_variable_index = variable_index
+
+ assert len(sorted_variable_index) == len(set(sorted_variable_index)), f'{sorted_variable_index=} contains repeated variables. Make sure you give a unique name for each variables.'
variable_index_map = {old: new for new, old in enumerate(sorted_variable_index)}
n_param = len(sorted_variable_index)
- def gen_underlying_matrices():
- for underlying in underlying_list:
+ def gen_numpy_matrices():
+ for polymatrix in polymatrix_list:
- n_row = underlying.shape[0]
+ n_row = polymatrix.shape[0]
buffer = MatrixBuffer(
data={},
@@ -487,34 +498,39 @@ def to_matrix_repr(
)
for row in range(n_row):
- underlying_terms = underlying.get_poly(row, 0)
- if underlying_terms is None:
+ polymatrix_terms = polymatrix.get_poly(row, 0)
+
+ if polymatrix_terms is None:
continue
+
+ if len(polymatrix_terms) == 0:
+ buffer.add(row, 0, 0, 0)
- for monomial, value in underlying_terms.items():
+ else:
+ for monomial, value in polymatrix_terms.items():
- def gen_new_monomial():
- for var, count in monomial:
- try:
- index = variable_index_map[var]
- except KeyError:
- raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}')
+ def gen_new_monomial():
+ for var, count in monomial:
+ try:
+ index = variable_index_map[var]
+ except KeyError:
+ raise KeyError(f'{var=} ({state.get_key_from_offset(var)}) is incompatible with {variable_index_map=}')
- for _ in range(count):
- yield index
+ for _ in range(count):
+ yield index
- new_monomial = tuple(gen_new_monomial())
+ new_monomial = tuple(gen_new_monomial())
- cols = monomial_to_index(n_param, new_monomial)
+ cols = monomial_to_index(n_param, new_monomial)
- col_value = value / len(cols)
+ col_value = value / len(cols)
- for col in cols:
- buffer.add(row, col, sum(count for _, count in monomial), col_value)
+ for col in cols:
+ buffer.add(row, col, sum(count for _, count in monomial), col_value)
yield buffer
- underlying_matrices = tuple(gen_underlying_matrices())
+ underlying_matrices = tuple(gen_numpy_matrices())
def gen_auxillary_equations():
for key, monomial_terms in state.auxillary_equations.items():
@@ -641,7 +657,7 @@ def to_sympy_repr(
if isinstance(variable, sympy.core.symbol.Symbol):
variable_name = variable.name
- elif isinstance(variable, ParametrizeExprMixin):
+ elif isinstance(variable, (ParametrizeExprMixin, ParametrizeMatrixExprMixin)):
variable_name = variable.name
elif isinstance(variable, str):
variable_name = variable