summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--polymatrix/__init__.py45
-rw-r--r--polymatrix/expression/accumulateexpr.py4
-rw-r--r--polymatrix/expression/fromtermsexpr.py4
-rw-r--r--polymatrix/expression/impl/expressionstateimpl.py2
-rw-r--r--polymatrix/expression/impl/fromtermsexprimpl.py8
-rw-r--r--polymatrix/expression/impl/toquadraticexprimpl.py8
-rw-r--r--polymatrix/expression/init/initaccumulateexpr.py25
-rw-r--r--polymatrix/expression/init/initexpressionstate.py2
-rw-r--r--polymatrix/expression/init/initfromtermsexpr.py14
-rw-r--r--polymatrix/expression/init/inittoquadraticexpr.py10
-rw-r--r--polymatrix/expression/mixins/accumulateexprmixin.py56
-rw-r--r--polymatrix/expression/mixins/additionexprmixin.py14
-rw-r--r--polymatrix/expression/mixins/derivativeexprmixin.py503
-rw-r--r--polymatrix/expression/mixins/determinantexprmixin.py27
-rw-r--r--polymatrix/expression/mixins/divisionexprmixin.py17
-rw-r--r--polymatrix/expression/mixins/elemmultexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/evalexprmixin.py14
-rw-r--r--polymatrix/expression/mixins/expressionbasemixin.py8
-rw-r--r--polymatrix/expression/mixins/expressionmixin.py36
-rw-r--r--polymatrix/expression/mixins/expressionstatemixin.py2
-rw-r--r--polymatrix/expression/mixins/forallexprmixin.py25
-rw-r--r--polymatrix/expression/mixins/fromarrayexprmixin.py15
-rw-r--r--polymatrix/expression/mixins/fromtermsexprmixin.py36
-rw-r--r--polymatrix/expression/mixins/getitemexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/kktexprmixin.py29
-rw-r--r--polymatrix/expression/mixins/matrixmultexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/parametrizetermsexprmixin.py84
-rw-r--r--polymatrix/expression/mixins/polymatrixasdictmixin.py8
-rw-r--r--polymatrix/expression/mixins/quadraticinexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/repmatexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/toquadraticexprmixin.py78
-rw-r--r--polymatrix/expression/mixins/transposeexprmixin.py10
-rw-r--r--polymatrix/expression/mixins/vstackexprmixin.py39
-rw-r--r--polymatrix/expression/toquadraticexpr.py4
34 files changed, 868 insertions, 309 deletions
diff --git a/polymatrix/__init__.py b/polymatrix/__init__.py
index 0f9bd81..3cf9f35 100644
--- a/polymatrix/__init__.py
+++ b/polymatrix/__init__.py
@@ -1,13 +1,11 @@
-# import typing
-# from polymatrix.init.initpolymatrixexpr import init_poly_matrix_expr
-# from polymatrix.init.initpolymatrixparamexpr import init_poly_matrix_param_expr
-
-
-from ast import Expression
+from polymatrix.expression.expression import Expression
+from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr
from polymatrix.expression.init.initexpression import init_expression
from polymatrix.expression.init.initfromarrayexpr import init_from_array_expr
+from polymatrix.expression.init.initfromtermsexpr import init_from_terms_expr
from polymatrix.expression.init.initkktexpr import init_kkt_expr
from polymatrix.expression.init.initvstackexpr import init_v_stack_expr
+from polymatrix.expression.polymatrix import PolyMatrix
def from_(
@@ -21,14 +19,33 @@ def from_(
init_from_array_expr(data)
)
- # return init_poly_matrix_expr(
- # underlying=init_poly_matrix_param_expr(
- # name=name,
- # degrees=degrees,
- # shape=shape,
- # re_index=re_index,
- # )
- # )
+
+def accumulate(
+ expr,
+ func,
+ initial = None,
+):
+ def lifted_func(acc, polymat: PolyMatrix):
+
+
+ # # print(f'{terms=}')
+ # print(f'{terms=}')
+
+ lifeted_val = init_expression(
+ underlying=init_from_terms_expr(
+ terms=polymat.terms,
+ shape=polymat.shape,
+ ),
+ )
+ return func(acc, lifeted_val)
+
+ return init_expression(
+ underlying=init_accumulate_expr(
+ underlying=expr.underlying,
+ acc_func=lifted_func,
+ initial=initial,
+ ),
+ )
def v_stack(
expressions: tuple[Expression],
diff --git a/polymatrix/expression/accumulateexpr.py b/polymatrix/expression/accumulateexpr.py
new file mode 100644
index 0000000..efcbc39
--- /dev/null
+++ b/polymatrix/expression/accumulateexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.accumulateexprmixin import AccumulateExprMixin
+
+class AccumulateExpr(AccumulateExprMixin):
+ pass
diff --git a/polymatrix/expression/fromtermsexpr.py b/polymatrix/expression/fromtermsexpr.py
new file mode 100644
index 0000000..2dbb5aa
--- /dev/null
+++ b/polymatrix/expression/fromtermsexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.fromtermsexprmixin import FromTermsExprMixin
+
+class FromTermsExpr(FromTermsExprMixin):
+ pass
diff --git a/polymatrix/expression/impl/expressionstateimpl.py b/polymatrix/expression/impl/expressionstateimpl.py
index 89afb56..7513970 100644
--- a/polymatrix/expression/impl/expressionstateimpl.py
+++ b/polymatrix/expression/impl/expressionstateimpl.py
@@ -7,5 +7,5 @@ from typing import Optional
class ExpressionStateImpl(ExpressionState):
n_param: int
offset_dict: dict
- auxillary_terms: tuple[dict[tuple[int], float]]
+ auxillary_equations: dict[int, dict[tuple[int], float]]
cached_polymatrix: dict
diff --git a/polymatrix/expression/impl/fromtermsexprimpl.py b/polymatrix/expression/impl/fromtermsexprimpl.py
new file mode 100644
index 0000000..92bb6a7
--- /dev/null
+++ b/polymatrix/expression/impl/fromtermsexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.fromtermsexpr import FromTermsExpr
+
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class FromTermsExprImpl(FromTermsExpr):
+ terms: tuple
+ shape: tuple[int, int]
diff --git a/polymatrix/expression/impl/toquadraticexprimpl.py b/polymatrix/expression/impl/toquadraticexprimpl.py
new file mode 100644
index 0000000..63a2cdb
--- /dev/null
+++ b/polymatrix/expression/impl/toquadraticexprimpl.py
@@ -0,0 +1,8 @@
+import dataclass_abc
+from polymatrix.expression.toquadraticexpr import ToQuadraticExpr
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+
+@dataclass_abc.dataclass_abc(frozen=True)
+class ToQuadraticExprImpl(ToQuadraticExpr):
+ underlying: ExpressionBaseMixin
diff --git a/polymatrix/expression/init/initaccumulateexpr.py b/polymatrix/expression/init/initaccumulateexpr.py
new file mode 100644
index 0000000..30297bf
--- /dev/null
+++ b/polymatrix/expression/init/initaccumulateexpr.py
@@ -0,0 +1,25 @@
+import dataclass_abc
+import typing
+
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.accumulateexpr import AccumulateExpr
+
+
+def init_accumulate_expr(
+ underlying: ExpressionBaseMixin,
+ acc_func: typing.Callable,
+ initial: ExpressionBaseMixin = None,
+):
+
+ @dataclass_abc.dataclass_abc(frozen=True)
+ class AccumulateExprImpl(AccumulateExpr):
+ underlying: ExpressionBaseMixin
+ initial: ExpressionBaseMixin
+
+ def acc_func(self, acc, v):
+ return acc_func(acc, v)
+
+ return AccumulateExprImpl(
+ underlying=underlying,
+ initial=initial,
+ )
diff --git a/polymatrix/expression/init/initexpressionstate.py b/polymatrix/expression/init/initexpressionstate.py
index a7d3aac..7e8a6fe 100644
--- a/polymatrix/expression/init/initexpressionstate.py
+++ b/polymatrix/expression/init/initexpressionstate.py
@@ -14,6 +14,6 @@ def init_expression_state(
return ExpressionStateImpl(
n_param=n_param,
offset_dict=offset_dict,
- auxillary_terms=tuple(),
+ auxillary_equations={},
cached_polymatrix={},
)
diff --git a/polymatrix/expression/init/initfromtermsexpr.py b/polymatrix/expression/init/initfromtermsexpr.py
new file mode 100644
index 0000000..80d5198
--- /dev/null
+++ b/polymatrix/expression/init/initfromtermsexpr.py
@@ -0,0 +1,14 @@
+from polymatrix.expression.impl.fromtermsexprimpl import FromTermsExprImpl
+
+
+def init_from_terms_expr(
+ terms: tuple,
+ shape: tuple[int, int]
+):
+ if isinstance(terms, dict):
+ terms = tuple((key, tuple(value.items())) for key, value in terms.items())
+
+ return FromTermsExprImpl(
+ terms=terms,
+ shape=shape,
+ )
diff --git a/polymatrix/expression/init/inittoquadraticexpr.py b/polymatrix/expression/init/inittoquadraticexpr.py
new file mode 100644
index 0000000..dfc0567
--- /dev/null
+++ b/polymatrix/expression/init/inittoquadraticexpr.py
@@ -0,0 +1,10 @@
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.impl.toquadraticexprimpl import ToQuadraticExprImpl
+
+
+def init_to_quadratic_expr(
+ underlying: ExpressionBaseMixin,
+):
+ return ToQuadraticExprImpl(
+ underlying=underlying,
+)
diff --git a/polymatrix/expression/mixins/accumulateexprmixin.py b/polymatrix/expression/mixins/accumulateexprmixin.py
new file mode 100644
index 0000000..c92051a
--- /dev/null
+++ b/polymatrix/expression/mixins/accumulateexprmixin.py
@@ -0,0 +1,56 @@
+
+import abc
+import typing
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class AccumulateExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ @abc.abstractmethod
+ def acc_func(self) -> typing.Callable:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def initial(self) -> ExpressionBaseMixin:
+ ...
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ acc = self.initial
+
+ for row in range(underlying.shape[0]):
+
+ col_terms = {}
+
+ for col in range(underlying.shape[1]):
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ col_terms[0, col] = underlying_terms
+
+ poly_matrix = init_poly_matrix(
+ terms=col_terms,
+ shape=(1, underlying.shape[1]),
+ )
+
+ acc = self.acc_func(acc, poly_matrix)
+
+ state, poly_matrix = acc.apply(state=state)
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/additionexprmixin.py b/polymatrix/expression/mixins/additionexprmixin.py
index 776e952..cab8161 100644
--- a/polymatrix/expression/mixins/additionexprmixin.py
+++ b/polymatrix/expression/mixins/additionexprmixin.py
@@ -20,10 +20,10 @@ class AddExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.left.shape
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.left.shape
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -39,8 +39,8 @@ class AddExprMixin(ExpressionBaseMixin):
for underlying in (left, right):
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
+ for row in range(left.shape[0]):
+ for col in range(left.shape[1]):
if (row, col) in terms:
terms_row_col = terms[row, col]
@@ -64,7 +64,7 @@ class AddExprMixin(ExpressionBaseMixin):
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=left.shape,
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/derivativeexprmixin.py b/polymatrix/expression/mixins/derivativeexprmixin.py
index 516aa38..6098aa9 100644
--- a/polymatrix/expression/mixins/derivativeexprmixin.py
+++ b/polymatrix/expression/mixins/derivativeexprmixin.py
@@ -1,9 +1,11 @@
import abc
import collections
+import dataclasses
+import itertools
import typing
import dataclass_abc
-from numpy import var
+from numpy import nonzero, var
from polymatrix.expression.init.initderivativekey import init_derivative_key
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
@@ -28,128 +30,477 @@ class DerivativeExprMixin(ExpressionBaseMixin):
def introduce_derivatives(self) -> bool:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- match self.variables:
- case ExpressionBaseMixin():
- n_cols = self.variables.shape[0]
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # match self.variables:
+ # case ExpressionBaseMixin():
+ # n_cols = self.variables.shape[0]
- case _:
- n_cols = len(self.variables)
+ # case _:
+ # n_cols = len(self.variables)
- return self.underlying.shape[0], n_cols
+ # return self.underlying.shape[0], n_cols
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
self,
state: PolyMatrixExprState,
) -> tuple[PolyMatrixExprState, PolyMatrix]:
- state = [state]
-
- state[0], underlying = self.underlying.apply(state=state[0])
match self.variables:
case ExpressionBaseMixin():
- assert self.variables.shape[1] == 1
+ state, variables = self.variables.apply(state)
- state[0], variables = self.variables.apply(state[0])
+ assert variables.shape[1] == 1
def gen_indices():
for row in range(variables.shape[0]):
- for monomial in variables.get_poly(row, 0).keys():
+ row_terms = variables.get_poly(row, 0)
+
+ assert len(row_terms) == 1
+
+ for monomial in row_terms.keys():
+ assert len(monomial) == 1
yield monomial[0]
- variable_indices = tuple(sorted(gen_indices()))
- # print(f'{variable_indices=}')
+ diff_wrt_variables = tuple(gen_indices())
case _:
def gen_indices():
- for variable in self.variable:
- if variable in state[0].offset_dict:
- yield state[0].offset_dict[variable][0]
+ for variable in self.variables:
+ if variable in state.offset_dict:
+ yield state.offset_dict[variable][0]
- variable_indices = tuple(sorted(gen_indices()))
+ diff_wrt_variables = tuple(gen_indices())
- terms = {}
- # aux_terms = []
+ def get_derivative_terms(
+ monomial_terms,
+ diff_wrt_variable: int,
+ state: PolyMatrixExprState,
+ considered_variables: set,
+ ):
- for var_idx, variable in enumerate(variable_indices):
-
- def get_derivative_terms(monomial_terms):
+ if self.introduce_derivatives:
+
+ def gen_new_variables():
+ for monomial in monomial_terms.keys():
+ for var in monomial:
+ if var not in diff_wrt_variables and var not in considered_variables:
+ yield var
+
+ new_variables = set(gen_new_variables())
+
+ new_considered_variables = considered_variables | new_variables
- terms_row_col = {}
+ def acc_state_candidates(acc, new_variable):
+ state, candidates = acc
- for monomial, value in monomial_terms.items():
+ key = init_derivative_key(
+ variable=new_variable,
+ with_respect_to=diff_wrt_variable,
+ )
+ state = state.register(key=key, n_param=1)
- # count powers for each variable
- monomial_cnt = dict(collections.Counter(monomial))
+ state, auxillary_derivation_terms = get_derivative_terms(
+ monomial_terms=state.auxillary_equations[new_variable],
+ diff_wrt_variable=diff_wrt_variable,
+ state=state,
+ considered_variables=new_considered_variables,
+ )
- if variable not in monomial_cnt:
- continue
+ if 1 < len(auxillary_derivation_terms):
+ derivation_variable = state.offset_dict[key][0]
+
+ state = dataclasses.replace(
+ state,
+ auxillary_equations=state.auxillary_equations | {derivation_variable: auxillary_derivation_terms},
+ )
+
+ return state, candidates + (new_variable,)
- if self.introduce_derivatives:
- variable_candidates = (variable,) + tuple(var for var in monomial_cnt.keys() if var not in variable_indices)
else:
- variable_candidates = (variable,)
-
- for variable_candidate in variable_candidates:
+ return state, candidates
- def generate_monomial():
- for current_variable, current_count in monomial_cnt.items():
+ *_, (state, confirmed_variables) = itertools.accumulate(
+ new_variables,
+ acc_state_candidates,
+ initial=(state, tuple()),
+ )
- if current_variable is variable_candidate:
- sel_counter = current_count - 1
+ else:
+ confirmed_variables = tuple()
- else:
- sel_counter = current_count
+ derivation_terms = collections.defaultdict(float)
- for _ in range(sel_counter):
- yield current_variable
+ for monomial, value in monomial_terms.items():
- if variable_candidate is not variable:
- key = init_derivative_key(
- variable=variable_candidate,
- with_respect_to=var_idx,
- )
- state[0] = state[0].register(key=key, n_param=1)
+ # count powers for each variable
+ monomial_cnt = dict(collections.Counter(monomial))
- yield state[0].offset_dict[key][0]
+ def differentiate_monomial(dependent_variable, derivation_variable=None):
+ def gen_diff_monomial():
+ for current_variable, current_count in monomial_cnt.items():
- col_monomial = tuple(generate_monomial())
+ if current_variable is dependent_variable:
+ sel_counter = current_count - 1
- if col_monomial not in terms_row_col:
- terms_row_col[col_monomial] = 0
+ else:
+ sel_counter = current_count
- terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]
+ for _ in range(sel_counter):
+ yield current_variable
- return terms_row_col
+ if derivation_variable is not None:
+ yield derivation_variable
- for row in range(self.shape[0]):
+ diff_monomial = tuple(sorted(gen_diff_monomial()))
- try:
- underlying_terms = underlying.get_poly(row, 0)
- except KeyError:
- continue
+ return diff_monomial, value * monomial_cnt[dependent_variable]
- derivative_terms = get_derivative_terms(underlying_terms)
+ if diff_wrt_variable in monomial_cnt:
+ diff_monomial, value = differentiate_monomial(diff_wrt_variable)
+ derivation_terms[diff_monomial] += value
- if 0 < len(derivative_terms):
- terms[row, var_idx] = derivative_terms
+ for candidate_variable in monomial_cnt.keys():
+ if candidate_variable in considered_variables or candidate_variable in confirmed_variables:
+ key = init_derivative_key(
+ variable=candidate_variable,
+ with_respect_to=diff_wrt_variable,
+ )
+ derivation_variable = state.offset_dict[key][0]
- # if self.introduce_derivatives:
- # for aux_monomial in underlying.aux_terms:
+ diff_monomial, value = differentiate_monomial(
+ dependent_variable=candidate_variable,
+ derivation_variable=derivation_variable,
+ )
+ derivation_terms[diff_monomial] += value
- # derivative_terms = get_derivative_terms(aux_monomial)
+ return state, dict(derivation_terms)
+
+ state, underlying = self.underlying.apply(state=state)
+
+ terms = {}
+
+ for row in range(underlying.shape[0]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, 0)
+ except KeyError:
+ continue
+
+ # derivate each variable and map result to the corresponding column
+ for col, diff_wrt_variable in enumerate(diff_wrt_variables):
+
+ state, derivation_terms = get_derivative_terms(
+ monomial_terms=underlying_terms,
+ diff_wrt_variable=diff_wrt_variable,
+ state=state,
+ considered_variables=set(),
+ )
+
+ if 0 < len(derivation_terms):
+ terms[row, col] = derivation_terms
- # # todo: is this correct?
- # if 1 < len(derivative_terms):
- # aux_terms.append(derivative_terms)
-
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
- # aux_terms=underlying.aux_terms + tuple(aux_terms),
+ shape=(underlying.shape[0], len(diff_wrt_variables)),
)
- return state[0], poly_matrix
+ return state, poly_matrix
+
+ # def get_derivative_terms(
+ # monomial_terms,
+ # diff_wrt_variable: int,
+ # state: PolyMatrixExprState,
+ # considered_variables: set,
+ # # implement_derivation: bool,
+ # ):
+ # derivation_terms = collections.defaultdict(float)
+
+ # other_independent_variables = tuple(var for var in diff_wrt_variables if var is not diff_wrt_variable)
+
+ # # print(other_independent_variables)
+ # # print(tuple(variable for monomial in monomial_terms.keys() for variable in monomial))
+
+ # if sum(variable not in other_independent_variables for monomial in monomial_terms.keys() for variable in monomial) < 2:
+ # return {}, state
+
+ # # if not implement_derivation:
+ # # implement_derivation = any(diff_wrt_variable in monomial for monomial in monomial_terms.keys())
+
+ # for monomial, value in monomial_terms.items():
+
+ # # count powers for each variable
+ # monomial_cnt = dict(collections.Counter(monomial))
+
+ # def differentiate_monomial(dependent_variable, derivation_variable=None):
+ # def gen_diff_monomial():
+ # for current_variable, current_count in monomial_cnt.items():
+
+ # if current_variable is dependent_variable:
+ # sel_counter = current_count - 1
+
+ # else:
+ # sel_counter = current_count
+
+ # for _ in range(sel_counter):
+ # yield current_variable
+
+ # if derivation_variable is not None:
+ # yield derivation_variable
+
+ # diff_monomial = tuple(gen_diff_monomial())
+
+ # return diff_monomial, value * monomial_cnt[dependent_variable]
+
+ # if diff_wrt_variable in monomial_cnt:
+ # diff_monomial, value = differentiate_monomial(diff_wrt_variable)
+ # derivation_terms[diff_monomial] += value
+
+ # if self.introduce_derivatives:
+
+ # def gen_derivation_keys():
+ # for variable in monomial_cnt.keys():
+ # if variable not in diff_wrt_variables:
+ # yield variable
+
+ # candidate_variables = tuple(gen_derivation_keys())
+
+ # new_considered_derivations = considered_variables | set(candidate_variables)
+
+ # for candidate_variable in candidate_variables:
+
+ # # introduce new auxillary equation
+ # if candidate_variable not in considered_variables:
+ # auxillary_derivation_terms, state = get_derivative_terms(
+ # monomial_terms=state.auxillary_equations[candidate_variable],
+ # diff_wrt_variable=diff_wrt_variable,
+ # state=state,
+ # considered_variables=new_considered_derivations,
+ # # implement_derivation=implement_derivation,
+ # )
+
+ # if 0 < len(auxillary_derivation_terms):
+ # key = init_derivative_key(
+ # variable=candidate_variable,
+ # with_respect_to=diff_wrt_variable,
+ # )
+ # state = state.register(key=key, n_param=1)
+ # derivation_variable = state.offset_dict[key][0]
+
+ # state = dataclasses.replace(
+ # state,
+ # auxillary_equations=state.auxillary_equations | {derivation_variable: auxillary_derivation_terms},
+ # )
+
+ # else:
+
+ # key = init_derivative_key(
+ # variable=candidate_variable,
+ # with_respect_to=diff_wrt_variable,
+ # )
+ # state = state.register(key=key, n_param=1)
+ # derivation_variable = state.offset_dict[key][0]
+
+ # diff_monomial, value = differentiate_monomial(
+ # dependent_variable=candidate_variable,
+ # derivation_variable=derivation_variable,
+ # )
+ # derivation_terms[diff_monomial] += value
+
+ # return dict(derivation_terms), state
+
+ # terms = {}
+
+ # for row in range(self.shape[0]):
+
+ # try:
+ # underlying_terms = underlying.get_poly(row, 0)
+ # except KeyError:
+ # continue
+
+ # # derivate each variable and map result to the corresponding column
+ # for col, diff_wrt_variable in enumerate(diff_wrt_variables):
+
+ # derivation_terms, state = get_derivative_terms(
+ # monomial_terms=underlying_terms,
+ # diff_wrt_variable=diff_wrt_variable,
+ # state=state,
+ # considered_variables=set(),
+ # # implement_derivation=False,
+ # )
+
+ # if 0 < len(derivation_terms):
+ # terms[row, col] = derivation_terms
+
+ # poly_matrix = init_poly_matrix(
+ # terms=terms,
+ # shape=self.shape,
+ # )
+
+ # return state, poly_matrix
+
+
+
+ # state = [state]
+
+ # state, underlying = self.underlying.apply(state=state)
+
+ # match self.variables:
+ # case ExpressionBaseMixin():
+ # assert self.variables.shape[1] == 1
+
+ # state, dependent_variables = self.variables.apply(state)
+
+ # def gen_indices():
+ # for row in range(dependent_variables.shape[0]):
+ # for monomial in dependent_variables.get_poly(row, 0).keys():
+ # yield monomial[0]
+
+ # variable_indices = tuple(sorted(gen_indices()))
+
+ # case _:
+ # def gen_indices():
+ # for variable in self.variables:
+ # if variable in state.offset_dict:
+ # yield state.offset_dict[variable][0]
+
+ # variable_indices = tuple(sorted(gen_indices()))
+
+ # terms = {}
+ # derivations_keys = set()
+
+ # # derivate each variable and map result to the corresponding column
+ # for col, derivation_variable in enumerate(variable_indices):
+
+ # def get_derivative_terms(monomial_terms):
+
+ # terms_row_col = collections.defaultdict(float)
+
+ # for monomial, value in monomial_terms.items():
+
+ # # count powers for each variable
+ # monomial_cnt = dict(collections.Counter(monomial))
+
+ # variable_candidates = tuple()
+
+ # if derivation_variable in monomial_cnt:
+ # variable_candidates += ((derivation_variable, None),)
+
+ # if self.introduce_derivatives:
+ # def gen_dependent_variables():
+ # for dependent_variable in monomial_cnt.keys():
+ # if dependent_variable not in variable_indices:
+ # derivation_key = init_derivative_key(
+ # variable=dependent_variable,
+ # with_respect_to=derivation_variable,
+ # )
+ # derivations_keys.add(derivation_key)
+ # state = state.register(key=derivation_key, n_param=1)
+ # yield dependent_variable, derivation_key
+
+ # variable_candidates += tuple(gen_dependent_variables())
+
+ # for variable_candidate, derivation_key in variable_candidates:
+
+ # def generate_monomial():
+ # for current_variable, current_count in monomial_cnt.items():
+
+ # if current_variable is variable_candidate:
+ # sel_counter = current_count - 1
+
+ # else:
+ # sel_counter = current_count
+
+ # for _ in range(sel_counter):
+ # yield current_variable
+
+ # if derivation_key is not None:
+ # yield state.offset_dict[derivation_key][0]
+
+ # col_monomial = tuple(generate_monomial())
+
+ # terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]
+
+ # return dict(terms_row_col)
+
+ # for row in range(self.shape[0]):
+
+ # try:
+ # underlying_terms = underlying.get_poly(row, 0)
+ # except KeyError:
+ # continue
+
+ # derivative_terms = get_derivative_terms(underlying_terms)
+
+ # if 0 < len(derivative_terms):
+ # terms[row, col] = derivative_terms
+
+ # derivation_variables = collections.defaultdict(list)
+ # for derivation_key in derivations_keys:
+ # derivation_variables[derivation_key.with_respect_to].append(derivation_key)
+
+ # aux_der_terms = []
+
+ # for derivation_variable, derivation_keys in derivation_variables.items():
+
+ # dependent_variables = tuple(derivation_key.variable for derivation_key in derivation_keys)
+
+
+ # for aux_terms in state.auxillary_equations:
+
+ # # only intoduce a new auxillary equation if there is a monomial containing at least one dependent variable
+ # if any(variable in dependent_variables for monomial in aux_terms.keys() for variable in monomial):
+
+ # terms_row_col = collections.defaultdict(float)
+
+ # # for each monomial
+ # for aux_monomial, value in aux_terms.items():
+
+ # # count powers for each variable
+ # monomial_cnt = dict(collections.Counter(aux_monomial))
+
+ # variable_candidates = tuple()
+
+ # if derivation_variable in monomial_cnt:
+ # variable_candidates += ((derivation_variable, None),)
+
+ # # add dependent variables
+ # variable_candidates += tuple((derivation_key.variable, derivation_key) for derivation_key in derivation_keys if derivation_key.variable in monomial_cnt)
+
+ # for variable_candidate, derivative_key in variable_candidates:
+
+ # def generate_monomial():
+ # for current_variable, current_count in monomial_cnt.items():
+
+ # if current_variable is variable_candidate:
+ # sel_counter = current_count - 1
+
+ # else:
+ # sel_counter = current_count
+
+ # for _ in range(sel_counter):
+ # yield current_variable
+
+ # if derivative_key is not None:
+ # yield state.offset_dict[derivative_key][0]
+
+ # col_monomial = tuple(generate_monomial())
+
+ # terms_row_col[col_monomial] += value * monomial_cnt[variable_candidate]
+
+ # if 0 < len(terms_row_col):
+ # aux_der_terms.append(dict(terms_row_col))
+
+ # state = dataclasses.replace(
+ # state,
+ # auxillary_equations=state.auxillary_equations + tuple(aux_der_terms),
+ # )
+
+ # poly_matrix = init_poly_matrix(
+ # terms=terms,
+ # shape=self.shape,
+ # )
+
+ # return state, poly_matrix
diff --git a/polymatrix/expression/mixins/determinantexprmixin.py b/polymatrix/expression/mixins/determinantexprmixin.py
index 8fe2536..048b263 100644
--- a/polymatrix/expression/mixins/determinantexprmixin.py
+++ b/polymatrix/expression/mixins/determinantexprmixin.py
@@ -16,10 +16,10 @@ class DeterminantExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.underlying.shape[0], 1
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.underlying.shape[0], 1
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -34,12 +34,12 @@ class DeterminantExprMixin(ExpressionBaseMixin):
assert underlying.shape[0] == underlying.shape[1]
inequality_terms = {}
- auxillary_terms = []
+ auxillary_equations = {}
index_start = state.n_param
rel_index = 0
- for row in range(self.shape[0]):
+ for row in range(underlying.shape[0]):
current_inequality_terms = collections.defaultdict(float)
@@ -66,13 +66,15 @@ class DeterminantExprMixin(ExpressionBaseMixin):
new_monomial = monomial + (index_start + rel_index + inner_row,)
current_inequality_terms[new_monomial] -= value
+ # auxillary terms
+ # ---------------
+
auxillary_term = collections.defaultdict(float)
for inner_col in range(row):
# P@x in P@x-v
key = tuple(reversed(sorted((inner_row, inner_col))))
- # terms = underlying.get_poly(*key)
try:
underlying_terms = underlying.get_poly(*key)
except KeyError:
@@ -83,7 +85,6 @@ class DeterminantExprMixin(ExpressionBaseMixin):
auxillary_term[new_monomial] += value
# -v in P@x-v
- # terms = underlying.get_poly(row, inner_row)
try:
underlying_terms = underlying.get_poly(row, inner_row)
except KeyError:
@@ -92,23 +93,23 @@ class DeterminantExprMixin(ExpressionBaseMixin):
for monomial, value in underlying_terms.items():
auxillary_term[monomial] -= value
- auxillary_terms.append(dict(auxillary_term))
+ x_variable = index_start + rel_index + inner_row
+ assert x_variable not in state.auxillary_equations
+ auxillary_equations[x_variable] = dict(auxillary_term)
rel_index += row
inequality_terms[row, 0] = dict(current_inequality_terms)
state = state.register(rel_index)
- # print(f'{auxillary_terms=}')
-
poly_matrix = init_poly_matrix(
terms=inequality_terms,
- shape=self.shape,
+ shape=(underlying.shape[0], 1),
)
state = dataclasses.replace(
state,
- auxillary_terms=state.auxillary_terms + tuple(auxillary_terms),
+ auxillary_equations=state.auxillary_equations | auxillary_equations,
cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
)
diff --git a/polymatrix/expression/mixins/divisionexprmixin.py b/polymatrix/expression/mixins/divisionexprmixin.py
index 01d3505..de5f685 100644
--- a/polymatrix/expression/mixins/divisionexprmixin.py
+++ b/polymatrix/expression/mixins/divisionexprmixin.py
@@ -19,16 +19,17 @@ class DivisionExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.left.shape
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.left.shape
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
self,
state: PolyMatrixExprState,
) -> tuple[PolyMatrixExprState, PolyMatrix]:
+
if self in state.cached_polymatrix:
return state, state.cached_polymatrix[self]
@@ -42,8 +43,8 @@ class DivisionExprMixin(ExpressionBaseMixin):
division_variable = state.n_param
state = state.register(n_param=1)
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
+ for row in range(left.shape[0]):
+ for col in range(left.shape[1]):
try:
underlying_terms = left.get_poly(row, col)
@@ -69,12 +70,12 @@ class DivisionExprMixin(ExpressionBaseMixin):
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=left.shape,
)
state = dataclasses.replace(
state,
- auxillary_terms=state.auxillary_terms + (auxillary_terms,),
+ auxillary_equations=state.auxillary_equations | {division_variable: auxillary_terms},
cached_polymatrix=state.cached_polymatrix | {self: poly_matrix},
)
diff --git a/polymatrix/expression/mixins/elemmultexprmixin.py b/polymatrix/expression/mixins/elemmultexprmixin.py
index 9684827..1f6172e 100644
--- a/polymatrix/expression/mixins/elemmultexprmixin.py
+++ b/polymatrix/expression/mixins/elemmultexprmixin.py
@@ -19,10 +19,10 @@ class ElemMultExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.left.shape
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.left.shape
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -68,7 +68,7 @@ class ElemMultExprMixin(ExpressionBaseMixin):
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=left.shape,
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/evalexprmixin.py b/polymatrix/expression/mixins/evalexprmixin.py
index b945064..88e0174 100644
--- a/polymatrix/expression/mixins/evalexprmixin.py
+++ b/polymatrix/expression/mixins/evalexprmixin.py
@@ -24,10 +24,10 @@ class EvalExprMixin(ExpressionBaseMixin):
def eval_values(self) -> tuple[float, ...]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.underlying.shape
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.underlying.shape
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -47,8 +47,8 @@ class EvalExprMixin(ExpressionBaseMixin):
terms = {}
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
try:
underlying_terms = underlying.get_poly(row, col)
@@ -87,7 +87,7 @@ class EvalExprMixin(ExpressionBaseMixin):
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=underlying.shape,
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/expressionbasemixin.py b/polymatrix/expression/mixins/expressionbasemixin.py
index d5109c0..43b3fdb 100644
--- a/polymatrix/expression/mixins/expressionbasemixin.py
+++ b/polymatrix/expression/mixins/expressionbasemixin.py
@@ -7,10 +7,10 @@ class ExpressionBaseMixin(
abc.ABC,
):
- @property
- @abc.abstractclassmethod
- def shape(self) -> tuple[int, int]:
- ...
+ # @property
+ # @abc.abstractclassmethod
+ # def shape(self) -> tuple[int, int]:
+ # ...
@abc.abstractmethod
def apply(self, state: ExpressionStateMixin) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
diff --git a/polymatrix/expression/mixins/expressionmixin.py b/polymatrix/expression/mixins/expressionmixin.py
index e50f660..6806293 100644
--- a/polymatrix/expression/mixins/expressionmixin.py
+++ b/polymatrix/expression/mixins/expressionmixin.py
@@ -3,6 +3,7 @@ import dataclasses
import typing
import numpy as np
from sympy import re
+from polymatrix.expression.init.initaccumulateexpr import init_accumulate_expr
from polymatrix.expression.init.initderivativeexpr import init_derivative_expr
from polymatrix.expression.init.initdeterminantexpr import init_determinant_expr
from polymatrix.expression.init.initdivisionexpr import init_division_expr
@@ -14,6 +15,7 @@ from polymatrix.expression.init.initgetitemexpr import init_get_item_expr
from polymatrix.expression.init.initparametrizesymbolsexpr import init_parametrize_symbols_expr
from polymatrix.expression.init.initquadraticinexpr import init_quadratic_in_expr
from polymatrix.expression.init.initrepmatexpr import init_rep_mat_expr
+from polymatrix.expression.init.inittoquadraticexpr import init_to_quadratic_expr
from polymatrix.expression.init.inittransposeexpr import init_transpose_expr
from polymatrix.init.initpolymatrixaddexpr import init_poly_matrix_add_expr
@@ -38,19 +40,19 @@ class ExpressionMixin(
def apply(self, state: PolyMatrixExprState) -> tuple[PolyMatrixExprState, PolyMatrix]:
return self.underlying.apply(state)
- # overwrites abstract method of `PolyMatrixExprBaseMixin`
- @property
- def degrees(self) -> set[int]:
- return self.underlying.degrees
+ # # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ # @property
+ # def degrees(self) -> set[int]:
+ # return self.underlying.degrees
- # overwrites abstract method of `PolyMatrixExprBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.underlying.shape
+ # # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.underlying.shape
- def __iter__(self):
- for row in range(self.shape[0]):
- yield self[row, 0]
+ # def __iter__(self):
+ # for row in range(self.shape[0]):
+ # yield self[row, 0]
def __add__(self, other: ExpressionBaseMixin) -> 'ExpressionMixin':
# assert self.underlying.shape == other.shape, f'shapes {(self.shape, other.shape)} of polynomial matrix do not match!'
@@ -128,7 +130,7 @@ class ExpressionMixin(
underlying=self.underlying,
index=key,
),
- )
+ )
@property
def T(self) -> 'ExpressionMixin':
@@ -172,7 +174,7 @@ class ExpressionMixin(
),
)
- def for_all(self, variables: tuple) -> 'ExpressionMixin':
+ def linear_in(self, variables: tuple) -> 'ExpressionMixin':
return dataclasses.replace(
self,
underlying=init_for_all_expr(
@@ -211,3 +213,11 @@ class ExpressionMixin(
eval_values=value,
),
)
+
+ def to_quadratic(self) -> 'ExpressionMixin':
+ return dataclasses.replace(
+ self,
+ underlying=init_to_quadratic_expr(
+ underlying=self.underlying,
+ ),
+ )
diff --git a/polymatrix/expression/mixins/expressionstatemixin.py b/polymatrix/expression/mixins/expressionstatemixin.py
index 67085d4..526b422 100644
--- a/polymatrix/expression/mixins/expressionstatemixin.py
+++ b/polymatrix/expression/mixins/expressionstatemixin.py
@@ -24,7 +24,7 @@ class ExpressionStateMixin(abc.ABC):
@property
@abc.abstractmethod
- def auxillary_terms(self) -> tuple[dict[tuple[int], float]]:
+ def auxillary_equations(self) -> dict[int, dict[tuple[int], float]]:
...
@property
diff --git a/polymatrix/expression/mixins/forallexprmixin.py b/polymatrix/expression/mixins/forallexprmixin.py
index 4f5577c..9a28283 100644
--- a/polymatrix/expression/mixins/forallexprmixin.py
+++ b/polymatrix/expression/mixins/forallexprmixin.py
@@ -1,5 +1,6 @@
import abc
+import collections
from numpy import var
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
@@ -19,10 +20,10 @@ class ForAllExprMixin(ExpressionBaseMixin):
def variables(self) -> tuple:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.underlying.shape
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.underlying.shape
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -36,37 +37,31 @@ class ForAllExprMixin(ExpressionBaseMixin):
terms = {}
idx_row = 0
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
try:
underlying_terms = underlying.get_poly(row, col)
except KeyError:
continue
- x_monomial_terms = {}
+ x_monomial_terms = collections.defaultdict(lambda: collections.defaultdict(float))
for monomial, value in underlying_terms.items():
x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
p_monomial = tuple(var_idx for var_idx in monomial if var_idx not in variable_indices)
-
- if x_monomial not in x_monomial_terms:
- x_monomial_terms[x_monomial] = {}
-
- if p_monomial not in x_monomial_terms:
- x_monomial_terms[x_monomial][p_monomial] = 0
x_monomial_terms[x_monomial][p_monomial] += value
for data in x_monomial_terms.values():
- terms[idx_row, 0] = data
+ terms[idx_row, 0] = dict(data)
idx_row += 1
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=(idx_row, 1),
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/fromarrayexprmixin.py b/polymatrix/expression/mixins/fromarrayexprmixin.py
index 0ae210d..86f8b72 100644
--- a/polymatrix/expression/mixins/fromarrayexprmixin.py
+++ b/polymatrix/expression/mixins/fromarrayexprmixin.py
@@ -1,12 +1,7 @@
import abc
-import collections
-import typing
-import numpy as np
-import dataclass_abc
from numpy import poly
import sympy
-import functools
from polymatrix.expression.init.initpolymatrix import init_poly_matrix
from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
@@ -20,10 +15,10 @@ class FromArrayExprMixin(ExpressionBaseMixin):
def data(self) -> tuple[tuple[float]]:
pass
- # overwrites abstract method of `PolyMatrixExprBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return len(self.data), len(self.data[0])
+ # # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return len(self.data), len(self.data[0])
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -70,7 +65,7 @@ class FromArrayExprMixin(ExpressionBaseMixin):
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=(len(self.data), len(self.data[0])),
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/fromtermsexprmixin.py b/polymatrix/expression/mixins/fromtermsexprmixin.py
new file mode 100644
index 0000000..4ada8ec
--- /dev/null
+++ b/polymatrix/expression/mixins/fromtermsexprmixin.py
@@ -0,0 +1,36 @@
+
+import abc
+from numpy import poly
+import sympy
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.mixins.expressionstatemixin import ExpressionStateMixin
+from polymatrix.expression.mixins.polymatrixmixin import PolyMatrixMixin
+
+
+class FromTermsExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def terms(self) -> tuple[tuple[tuple[float], float], ...]:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def shape(self) -> tuple[int, int]:
+ pass
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: ExpressionStateMixin,
+ ) -> tuple[ExpressionStateMixin, PolyMatrixMixin]:
+
+ terms = {coord: dict(monomials) for coord, monomials in self.terms}
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=self.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/getitemexprmixin.py b/polymatrix/expression/mixins/getitemexprmixin.py
index 9a521cb..57ffe38 100644
--- a/polymatrix/expression/mixins/getitemexprmixin.py
+++ b/polymatrix/expression/mixins/getitemexprmixin.py
@@ -22,10 +22,10 @@ class GetItemExprMixin(ExpressionBaseMixin):
def index(self) -> tuple[int, int]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return 1, 1
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return 1, 1
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -48,7 +48,7 @@ class GetItemExprMixin(ExpressionBaseMixin):
return state, GetItemPolyMatrix(
underlying=underlying,
- shape=self.shape,
+ shape=(1, 1),
index=self.index,
# aux_terms=underlying.aux_terms,
)
diff --git a/polymatrix/expression/mixins/kktexprmixin.py b/polymatrix/expression/mixins/kktexprmixin.py
index f6066d8..ecc69ed 100644
--- a/polymatrix/expression/mixins/kktexprmixin.py
+++ b/polymatrix/expression/mixins/kktexprmixin.py
@@ -1,7 +1,6 @@
import abc
import itertools
-from this import d
import typing
import dataclass_abc
from polymatrix.expression.derivativekey import DerivativeKey
@@ -34,10 +33,10 @@ class KKTExprMixin(ExpressionBaseMixin):
def variables(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.cost.shape[0] + self.equality.shape[0], 1
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.cost.shape[0] + self.equality.shape[0], 1
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -45,12 +44,14 @@ class KKTExprMixin(ExpressionBaseMixin):
state: PolyMatrixExprState,
) -> tuple[PolyMatrixExprState, PolyMatrix]:
- assert self.cost.shape == self.variables.shape
- assert self.cost.shape[1] == 1
-
state, cost = self.cost.apply(state=state)
state, equality = self.equality.apply(state=state)
+ # todo: check shape
+ # assert cost.shape == self.variables.shape
+
+ assert cost.shape[1] == 1
+
state, equality_der = self.equality.diff(
self.variables,
introduce_derivatives=True,
@@ -65,7 +66,7 @@ class KKTExprMixin(ExpressionBaseMixin):
return state, nu_variables + [nu_variable]
*_, (state, nu_variables) = tuple(itertools.accumulate(
- range(self.equality.shape[0]),
+ range(equality.shape[0]),
acc_nu_variables,
initial=(state, []),
))
@@ -98,7 +99,7 @@ class KKTExprMixin(ExpressionBaseMixin):
terms[row, 0] = monomial_terms
idx_start = n_row
-
+
for row in range(equality.shape[0]):
try:
@@ -110,14 +111,6 @@ class KKTExprMixin(ExpressionBaseMixin):
idx_start += equality.shape[0]
- # for row, aux_term in enumerate(state.auxillary_terms):
- # terms[idx_start + row, 0] = aux_term
-
- # idx_start += len(state.auxillary_terms)
-
- # derivatives = tuple(key for key in state.offset_dict.keys() if isinstance(key, DerivativeKey))
- # print(f'{derivatives=}')
-
poly_matrix = init_poly_matrix(
terms=terms,
shape=(idx_start, 1),
diff --git a/polymatrix/expression/mixins/matrixmultexprmixin.py b/polymatrix/expression/mixins/matrixmultexprmixin.py
index 29eb4ac..1fc4977 100644
--- a/polymatrix/expression/mixins/matrixmultexprmixin.py
+++ b/polymatrix/expression/mixins/matrixmultexprmixin.py
@@ -19,10 +19,10 @@ class MatrixMultExprMixin(ExpressionBaseMixin):
def right(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return (self.left.shape[0], self.right.shape[1])
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return (self.left.shape[0], self.right.shape[1])
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -69,7 +69,7 @@ class MatrixMultExprMixin(ExpressionBaseMixin):
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=(left.shape[0], right.shape[1]),
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/parametrizetermsexprmixin.py b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
index fbee57c..9c9795d 100644
--- a/polymatrix/expression/mixins/parametrizetermsexprmixin.py
+++ b/polymatrix/expression/mixins/parametrizetermsexprmixin.py
@@ -26,9 +26,9 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
def variables(self) -> tuple:
...
- @property
- def shape(self) -> tuple[int, int]:
- return self.underlying.shape
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.underlying.shape
def _internal_apply(self, state: ExpressionStateMixin):
if not hasattr(self, '_terms'):
@@ -40,12 +40,11 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
variable_indices = tuple(state.offset_dict[variable][0] for variable in self.variables if variable in state.offset_dict)
idx_start = state.n_param
- # print(f'{idx_start=}')
n_param = 0
terms = {}
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
try:
underlying_terms = underlying.get_poly(row, col)
@@ -61,9 +60,6 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
if x_monomial not in collected_terms:
collected_terms.append(x_monomial)
-
- # print(f'{x_monomial=}')
- # print(f'{collected_terms=}')
idx = idx_start + n_param + collected_terms.index(x_monomial)
@@ -78,7 +74,8 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
self._terms = terms
self._start_index = idx_start
- # self._n_param = n_param
+ self._shape = underlying.shape
+ self._n_param = n_param
return state, self._terms
@@ -86,44 +83,14 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
def param(self) -> tuple[int, int]:
outer_self = self
- # precalculate number of parameters (used for `shape` attribute)
- # ---------------------------------
- # not pretty
-
- dummy_state = init_expression_state()
-
- dummy_state, underlying = self.underlying.apply(dummy_state)
-
- for variable in self.variables:
- dummy_state = dummy_state.register(key=variable, n_param=1)
-
- variable_indices = tuple(dummy_state.offset_dict[variable][0] for variable in self.variables if variable in dummy_state.offset_dict)
-
- n_param = 0
-
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
-
- underlying_terms = underlying.get_poly(row, col)
-
- collected_terms = []
-
- for monomial in underlying_terms.keys():
-
- x_monomial = tuple(var_idx for var_idx in monomial if var_idx in variable_indices)
-
- if x_monomial not in collected_terms:
- collected_terms.append(x_monomial)
-
- n_param += len(collected_terms)
-
@dataclass_abc.dataclass_abc(frozen=True)
- class ParameterExprMixin(ExpressionBaseMixin):
- n_param: int
+ class ParameterExpr(ExpressionBaseMixin):
+ # n_param: int
+ # start_index: int
- @property
- def shape(self) -> tuple[int, int]:
- return self.n_param, 1
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.n_param, 1
def apply(
self,
@@ -132,32 +99,23 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
state, _ = outer_self._internal_apply(state)
+ n_param = outer_self._n_param
+ start_index = outer_self._start_index
+
def gen_monomials():
- for rel_index in range(self.n_param):
- yield {(outer_self._start_index + rel_index,): 1}
+ for rel_index in range(n_param):
+ yield {(start_index + rel_index,): 1}
terms = {(row, 0): monomial_terms for row, monomial_terms in enumerate(gen_monomials())}
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=(n_param, 1),
)
return state, poly_matrix
- return ParameterExprMixin(
- n_param=n_param,
- )
-
- # # overwrite this method to customize indexing
- # def re_index(
- # self,
- # degree: int,
- # poly_col: int,
- # poly_row: int,
- # x_monomial: tuple[int, ...],
- # ) -> tuple[int, int, tuple[int, ...], float]:
- # return poly_col, poly_row, x_monomial, 1.0
+ return ParameterExpr()
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -169,7 +127,7 @@ class ParametrizeTermsExprMixin(ExpressionBaseMixin):
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=self._shape,
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/polymatrixasdictmixin.py b/polymatrix/expression/mixins/polymatrixasdictmixin.py
index 6aea8ed..4a366de 100644
--- a/polymatrix/expression/mixins/polymatrixasdictmixin.py
+++ b/polymatrix/expression/mixins/polymatrixasdictmixin.py
@@ -14,14 +14,6 @@ class PolyMatrixAsDictMixin(
# overwrites abstract method of `PolyMatrixMixin`
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
- # key = (row, col)
-
- # if key in self.terms:
- # return self.terms[key]
-
- # else:
- # return None
-
try:
return self.terms[row, col]
except KeyError:
diff --git a/polymatrix/expression/mixins/quadraticinexprmixin.py b/polymatrix/expression/mixins/quadraticinexprmixin.py
index 8f5555e..ee0e5a9 100644
--- a/polymatrix/expression/mixins/quadraticinexprmixin.py
+++ b/polymatrix/expression/mixins/quadraticinexprmixin.py
@@ -18,10 +18,10 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
def variables(self) -> tuple:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return 2*(len(self.variables),)
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return 2*(len(self.variables),)
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -66,7 +66,7 @@ class QuadraticInExprMixin(ExpressionBaseMixin):
poly_matrix = init_poly_matrix(
terms=terms,
- shape=self.shape,
+ shape=2*(len(self.variables),),
)
return state, poly_matrix
diff --git a/polymatrix/expression/mixins/repmatexprmixin.py b/polymatrix/expression/mixins/repmatexprmixin.py
index aec0d48..26777e9 100644
--- a/polymatrix/expression/mixins/repmatexprmixin.py
+++ b/polymatrix/expression/mixins/repmatexprmixin.py
@@ -16,9 +16,9 @@ class RepMatExprMixin(ExpressionBaseMixin):
def repetition(self) -> tuple[int, int]:
...
- @property
- def shape(self) -> tuple[int, int]:
- return tuple(s*r for s, r in zip(self.underlying.shape, self.repetition))
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return tuple(s*r for s, r in zip(self.underlying.shape, self.repetition))
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
@@ -32,7 +32,6 @@ class RepMatExprMixin(ExpressionBaseMixin):
class RepMatPolyMatrix(PolyMatrixMixin):
underlying: PolyMatrixMixin
shape: tuple[int, int]
- # aux_terms: tuple
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
n_row, n_col = underlying.shape
@@ -44,6 +43,5 @@ class RepMatExprMixin(ExpressionBaseMixin):
return state, RepMatPolyMatrix(
underlying=underlying,
- shape=self.shape,
- # aux_terms=underlying.aux_terms,
+ shape=tuple(s*r for s, r in zip(underlying.shape, self.repetition)),
) \ No newline at end of file
diff --git a/polymatrix/expression/mixins/toquadraticexprmixin.py b/polymatrix/expression/mixins/toquadraticexprmixin.py
new file mode 100644
index 0000000..d1b6fce
--- /dev/null
+++ b/polymatrix/expression/mixins/toquadraticexprmixin.py
@@ -0,0 +1,78 @@
+
+import abc
+import collections
+import dataclasses
+
+from polymatrix.expression.init.initpolymatrix import init_poly_matrix
+from polymatrix.expression.mixins.expressionbasemixin import ExpressionBaseMixin
+from polymatrix.expression.polymatrix import PolyMatrix
+from polymatrix.polymatrixexprstate import PolyMatrixExprState
+
+
+class ToQuadraticExprMixin(ExpressionBaseMixin):
+ @property
+ @abc.abstractmethod
+ def underlying(self) -> ExpressionBaseMixin:
+ ...
+
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.underlying.shape
+
+ # overwrites abstract method of `ExpressionBaseMixin`
+ def apply(
+ self,
+ state: PolyMatrixExprState,
+ ) -> tuple[PolyMatrixExprState, PolyMatrix]:
+ state, underlying = self.underlying.apply(state=state)
+
+ terms = {}
+ auxillary_equations = {}
+
+ for row in range(underlying.shape[0]):
+ for col in range(underlying.shape[1]):
+
+ try:
+ underlying_terms = underlying.get_poly(row, col)
+ except KeyError:
+ continue
+
+ terms_row_col = collections.defaultdict(float)
+
+ for monomial, value in underlying_terms.items():
+
+ if 2 < len(monomial):
+ current_aux = state.n_param
+ terms_row_col[(monomial[0], current_aux)] += value
+ state = state.register(n_param=1)
+
+ for variable in monomial[1:-2]:
+ auxillary_equations[current_aux] = {
+ (variable, current_aux + 1): 1,
+ (current_aux,): -1,
+ }
+ state = state.register(n_param=1)
+ current_aux += 1
+
+ auxillary_equations[current_aux] = {
+ (monomial[-2], monomial[-1]): 1,
+ (current_aux,): -1,
+ }
+
+ else:
+ terms_row_col[monomial] += value
+
+ terms[row, col] = dict(terms_row_col)
+
+ state = dataclasses.replace(
+ state,
+ auxillary_equations=state.auxillary_equations | auxillary_equations,
+ )
+
+ poly_matrix = init_poly_matrix(
+ terms=terms,
+ shape=underlying.shape,
+ )
+
+ return state, poly_matrix
diff --git a/polymatrix/expression/mixins/transposeexprmixin.py b/polymatrix/expression/mixins/transposeexprmixin.py
index 63e6a3f..4d09380 100644
--- a/polymatrix/expression/mixins/transposeexprmixin.py
+++ b/polymatrix/expression/mixins/transposeexprmixin.py
@@ -17,10 +17,10 @@ class TransposeExprMixin(ExpressionBaseMixin):
def underlying(self) -> ExpressionBaseMixin:
...
- # overwrites abstract method of `PolyMatrixExprBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- return self.underlying.shape[1], self.underlying.shape[0]
+ # # overwrites abstract method of `PolyMatrixExprBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # return self.underlying.shape[1], self.underlying.shape[0]
# overwrites abstract method of `PolyMatrixExprBaseMixin`
def apply(
@@ -39,5 +39,5 @@ class TransposeExprMixin(ExpressionBaseMixin):
return state, TransposePolyMatrix(
underlying=underlying,
- shape=self.shape,
+ shape=(underlying.shape[1], underlying.shape[0]),
) \ No newline at end of file
diff --git a/polymatrix/expression/mixins/vstackexprmixin.py b/polymatrix/expression/mixins/vstackexprmixin.py
index 381659c..0e95016 100644
--- a/polymatrix/expression/mixins/vstackexprmixin.py
+++ b/polymatrix/expression/mixins/vstackexprmixin.py
@@ -15,34 +15,33 @@ class VStackExprMixin(ExpressionBaseMixin):
def underlying(self) -> tuple[ExpressionBaseMixin, ...]:
...
- # overwrites abstract method of `ExpressionBaseMixin`
- @property
- def shape(self) -> tuple[int, int]:
- n_row = sum(expr.shape[0] for expr in self.underlying)
- return n_row, self.underlying[0].shape[1]
+ # # overwrites abstract method of `ExpressionBaseMixin`
+ # @property
+ # def shape(self) -> tuple[int, int]:
+ # n_row = sum(expr.shape[0] for expr in self.underlying)
+ # return n_row, self.underlying[0].shape[1]
# overwrites abstract method of `ExpressionBaseMixin`
def apply(
self,
state: PolyMatrixExprState,
) -> tuple[PolyMatrixExprState, PolyMatrix]:
- assert all(expr.shape[1] == self.underlying[0].shape[1] for expr in self.underlying)
-
- # todo: rename
- underlying = []
+
+ all_underlying = []
for expr in self.underlying:
state, polymat = expr.apply(state=state)
- underlying.append(polymat)
+ all_underlying.append(polymat)
+
+ assert all(underlying.shape[1] == all_underlying[0].shape[1] for underlying in all_underlying)
@dataclass_abc.dataclass_abc(frozen=True)
class VStackPolyMatrix(PolyMatrixMixin):
- underlying: tuple[PolyMatrixMixin]
+ all_underlying: tuple[PolyMatrixMixin]
underlying_row_range: tuple[tuple[int, int], ...]
shape: tuple[int, int]
- # aux_terms: tuple
def get_poly(self, row: int, col: int) -> dict[tuple[int, ...], float]:
- for polymat, (row_start, row_end) in zip(self.underlying, self.underlying_row_range):
+ for polymat, (row_start, row_end) in zip(self.all_underlying, self.underlying_row_range):
if row_start <= row < row_end:
return polymat.get_poly(
row=row - row_start,
@@ -53,14 +52,16 @@ class VStackExprMixin(ExpressionBaseMixin):
underlying_row_range = tuple(itertools.pairwise(
itertools.accumulate(
- (expr.shape[0] for expr in self.underlying),
+ (expr.shape[0] for expr in all_underlying),
initial=0)
))
- return state, VStackPolyMatrix(
- underlying=underlying,
- shape=self.shape,
+ n_row = sum(polymat.shape[0] for polymat in all_underlying)
+
+ polymat = VStackPolyMatrix(
+ all_underlying=all_underlying,
+ shape=(n_row, all_underlying[0].shape[1]),
underlying_row_range=underlying_row_range,
- # aux_terms=tuple(aux_term for expr in underlying for aux_term in expr.aux_terms)
)
- \ No newline at end of file
+
+ return state, polymat
diff --git a/polymatrix/expression/toquadraticexpr.py b/polymatrix/expression/toquadraticexpr.py
new file mode 100644
index 0000000..ff78d8c
--- /dev/null
+++ b/polymatrix/expression/toquadraticexpr.py
@@ -0,0 +1,4 @@
+from polymatrix.expression.mixins.toquadraticexprmixin import ToQuadraticExprMixin
+
+class ToQuadraticExpr(ToQuadraticExprMixin):
+ pass