Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""Tools for manipulating of large commutative expressions. """
from __future__ import print_function, division
from sympy.core.add import Add from sympy.core.compatibility import iterable, is_sequence, SYMPY_INTS, range from sympy.core.mul import Mul, _keep_coeff from sympy.core.power import Pow from sympy.core.basic import Basic, preorder_traversal from sympy.core.expr import Expr from sympy.core.sympify import sympify from sympy.core.numbers import Rational, Integer, Number, I from sympy.core.singleton import S from sympy.core.symbol import Dummy from sympy.core.coreerrors import NonCommutativeExpression from sympy.core.containers import Tuple, Dict from sympy.utilities import default_sort_key from sympy.utilities.iterables import (common_prefix, common_suffix, variations, ordered)
from collections import defaultdict
_eps = Dummy(positive=True)
def _isnumber(i):
def _monotonic_sign(self): """Return the value closest to 0 that ``self`` may have if all symbols are signed and the result is uniformly the same sign for all values of symbols. If a symbol is only signed but not known to be an integer or the result is 0 then a symbol representative of the sign of self will be returned. Otherwise, None is returned if a) the sign could be positive or negative or b) self is not in one of the following forms:
- L(x, y, ...) + A: a function linear in all symbols x, y, ... with an additive constant; if A is zero then the function can be a monomial whose sign is monotonic over the range of the variables, e.g. (x + 1)**3 if x is nonnegative. - A/L(x, y, ...) + B: the inverse of a function linear in all symbols x, y, ... that does not have a sign change from positive to negative for any set of values for the variables. - M(x, y, ...) + A: a monomial M whose factors are all signed and a constant, A. - A/M(x, y, ...) + B: the inverse of a monomial and constants A and B. - P(x): a univariate polynomial
Examples ========
>>> from sympy.core.exprtools import _monotonic_sign as F >>> from sympy import Dummy, S >>> nn = Dummy(integer=True, nonnegative=True) >>> p = Dummy(integer=True, positive=True) >>> p2 = Dummy(integer=True, positive=True) >>> F(nn + 1) 1 >>> F(p - 1) _nneg >>> F(nn*p + 1) 1 >>> F(p2*p + 1) 2 >>> F(nn - 1) # could be negative, zero or positive """
if s.is_odd: return S(3) else: return S(2) return S(2) return S.One else: return S(-2) return S.NegativeOne else:
# univariate polynomial else: except (PolynomialError, NotImplementedError): roots = [r for r in roots(d, x) if r.is_real] else: else: if y: return Dummy('pos', positive=True) else: return Dummy('nneg', nonnegative=True) else: return Dummy('npos', nonpositive=True) else: elif v.is_nonnegative: return Dummy('nneg', nonnegative=True) elif v.is_negative: return Dummy('neg', negative=True) elif v.is_nonpositive: return Dummy('npos', nonpositive=True)
# multivariate # F/A or A/F where A is a number and F is a signed, rational monomial a.is_Mul or a.is_Pow) and \ a.is_rational and \ all(p.exp.is_Integer for p in a.atoms(Pow) if p.is_Pow) and \ (a.is_positive or a.is_negative): v = S(1) for ai in Mul.make_args(a): if ai.is_number: v *= ai continue reps = {} for x in ai.free_symbols: reps[x] = _monotonic_sign(x) if reps[x] is None: return v *= ai.subs(reps) # signed linear expression if not any(p for p in a.atoms(Pow) if not p.is_number) and (a.is_nonpositive or a.is_nonnegative): free = list(a.free_symbols) p = {} for i in free: v = _monotonic_sign(i) if v is None: return p[i] = v or (_eps if i.is_nonnegative else -_eps) v = a.xreplace(p) rv = v + c if v.is_nonnegative and rv.is_positive: return rv.subs(_eps, 0) if v.is_nonpositive and rv.is_negative: return rv.subs(_eps, 0)
def decompose_power(expr): """ Decompose power into symbolic base and integer exponent.
This is strictly only valid if the exponent from which the integer is extracted is itself an integer or the base is positive. These conditions are assumed and not checked here.
Examples ========
>>> from sympy.core.exprtools import decompose_power >>> from sympy.abc import x, y
>>> decompose_power(x) (x, 1) >>> decompose_power(x**2) (x, 2) >>> decompose_power(x**(2*y)) (x**y, 2) >>> decompose_power(x**(2*y/3)) (x**(y/3), 2)
"""
else: base, exp = expr, 1 else:
else:
def decompose_power_rat(expr):
    """
    Decompose power into symbolic base and rational exponent.

    The starting pair is ``expr.as_base_exp()``.  When the exponent is a
    Number it is returned as is if it is Rational, otherwise ``(expr, 1)``
    is returned.  For any other exponent the rational coefficient is
    pulled out of the exponent and what is left of it is moved into the
    base.

    """
    base, exponent = expr.as_base_exp()

    if exponent.is_Number:
        # numeric exponent: keep it only when it is Rational
        if exponent.is_Rational:
            return base, exponent
        return expr, 1

    coeff, rest = exponent.as_coeff_Mul(rational=True)

    if coeff is S.NegativeOne:
        return Pow(base, rest), -1
    if coeff is S.One:
        # nothing rational to extract
        return expr, 1

    # base**((p/q)*rest) -> (base**(rest/q), p)
    rest = _keep_coeff(Rational(1, coeff.q), rest)
    return Pow(base, rest), coeff.p
class Factors(object): """Efficient representation of ``f_1*f_2*...*f_n``."""
__slots__ = ['factors', 'gens']
def __init__(self, factors=None): # Factors """Initialize Factors from dict or expr.
Examples ========
>>> from sympy.core.exprtools import Factors >>> from sympy.abc import x >>> from sympy import I >>> e = 2*x**3 >>> Factors(e) Factors({2: 1, x: 3}) >>> Factors(e.as_powers_dict()) Factors({2: 1, x: 3}) >>> f = _ >>> f.factors # underlying dictionary {2: 1, x: 3} >>> f.gens # base of each factor frozenset([2, x]) >>> Factors(0) Factors({0: 1}) >>> Factors(I) Factors({I: 1})
Notes =====
Although a dictionary can be passed, only minimal checking is performed: powers of -1 and I are made canonical.
""" factors = S(factors)
factors = factors.factors.copy() factors = {S.Zero: S.One} factors[n] = S.One # since we're processing Numbers, the denominator is # stored with a negative exponent; all other factors # are left . factors[Integer(n.p)] = S.One else: raise ValueError('Expected Float|Rational|Integer, not %s' % n) c.remove(I) factors[I] = S.One*i factors[Mul(*nc, evaluate=False)] = S.One else:
# tidy up -/+1 and I exponents if Rational
continue elif a.is_Pow: if S.NegativeOne not in factors: factors[S.NegativeOne] = S.Zero factors[S.NegativeOne] += a.exp elif a == 1: factors[a] = S.One elif a == -1: factors[-a] = S.One factors[S.NegativeOne] = S.One else: raise ValueError('unexpected factor in i1: %s' % a)
except AttributeError: raise TypeError('expecting Expr or dictionary')
def __hash__(self):  # Factors
    """Return a hash built from the bases and their exponents.

    The bases are put in canonical order with ``ordered`` so the hash
    does not depend on the iteration order of the ``factors`` dictionary.
    """
    keys = tuple(ordered(self.factors.keys()))
    # The exponents must be collected in a tuple: a list is unhashable,
    # so hashing ``(keys, [...])`` would raise TypeError.
    values = tuple(self.factors[k] for k in keys)
    return hash((keys, values))
def __repr__(self):  # Factors
    """Return ``Factors({base: exp, ...})`` with the items in ``ordered`` order."""
    items = ordered(self.factors.items())
    body = ', '.join(['%s: %s' % (base, exp) for base, exp in items])
    return "Factors({%s})" % body
@property def is_zero(self): # Factors """ >>> from sympy.core.exprtools import Factors >>> Factors(0).is_zero True """
@property def is_one(self): # Factors """ >>> from sympy.core.exprtools import Factors >>> Factors(1).is_one True """ return not self.factors
def as_expr(self): # Factors """Return the underlying expression.
Examples ========
>>> from sympy.core.exprtools import Factors >>> from sympy.abc import x, y >>> Factors((x*y**2).as_powers_dict()).as_expr() x*y**2
"""
else: else:
def mul(self, other): # Factors """Return Factors of ``self * other``.
Examples ========
>>> from sympy.core.exprtools import Factors >>> from sympy.abc import x, y, z >>> a = Factors((x*y**2).as_powers_dict()) >>> b = Factors((x*y/z).as_powers_dict()) >>> a.mul(b) Factors({x: 2, y: 3, z: -1}) >>> a*b Factors({x: 2, y: 3, z: -1}) """ other = Factors(other) return Factors(S.Zero)
del factors[factor] continue
def normal(self, other): """Return ``self`` and ``other`` with ``gcd`` removed from each. The only differences between this and method ``div`` is that this is 1) optimized for the case when there are few factors in common and 2) this does not raise an error if ``other`` is zero.
See Also ======== div
""" other = Factors(other) if other.is_zero: return (Factors(), Factors(S.Zero)) if self.is_zero: return (Factors(S.Zero), Factors())
else: else: r = self_exp.extract_additively(other_exp) if r is not None: if r: self_factors[factor] = r del other_factors[factor] else: # should be handled already del self_factors[factor] del other_factors[factor] else: sc, sa = self_exp.as_coeff_Add() if sc: oc, oa = other_exp.as_coeff_Add() diff = sc - oc if diff > 0: self_factors[factor] -= oc other_exp = oa elif diff < 0: self_factors[factor] -= sc other_factors[factor] -= sc other_exp = oa - diff else: self_factors[factor] = sa other_exp = oa if other_exp: other_factors[factor] = other_exp else: del other_factors[factor]
def div(self, other): # Factors """Return ``self`` and ``other`` with ``gcd`` removed from each. This is optimized for the case when there are many factors in common.
Examples ========
>>> from sympy.core.exprtools import Factors >>> from sympy.abc import x, y, z >>> from sympy import S
>>> a = Factors((x*y**2).as_powers_dict()) >>> a.div(a) (Factors({}), Factors({})) >>> a.div(x*z) (Factors({y: 2}), Factors({z: 1}))
The ``/`` operator only gives ``quo``:
>>> a/x Factors({y: 2})
Factors treats its factors as though they are all in the numerator, so if you violate this assumption the results will be correct but will not strictly correspond to the numerator and denominator of the ratio:
>>> a.div(x/z) (Factors({y: 2}), Factors({z: -1}))
Factors is also naive about bases: it does not attempt any denesting of Rational-base terms, for example the following does not become 2**(2*x)/2.
>>> Factors(2**(2*x + 2)).div(S(8)) (Factors({2: 2*x + 2}), Factors({8: 1}))
factor_terms can clean up such Rational-bases powers:
>>> from sympy.core.exprtools import factor_terms >>> n, d = Factors(2**(2*x + 2)).div(S(8)) >>> n.as_expr()/d.as_expr() 2**(2*x + 2)/8 >>> factor_terms(_) 2**(2*x)/2
"""
raise ZeroDivisionError return (Factors(S.Zero), Factors())
quo[factor] = d
continue
exp = -d
else: r = quo[factor].extract_additively(exp) if r is not None: if r: quo[factor] = r else: # should be handled already del quo[factor] else: other_exp = exp sc, sa = quo[factor].as_coeff_Add() if sc: oc, oa = other_exp.as_coeff_Add() diff = sc - oc if diff > 0: quo[factor] -= oc other_exp = oa elif diff < 0: quo[factor] -= sc other_exp = oa - diff else: quo[factor] = sa other_exp = oa if other_exp: rem[factor] = other_exp else: assert factor not in rem continue
rem[factor] = exp
def quo(self, other): # Factors """Return numerator Factor of ``self / other``.
Examples ========
>>> from sympy.core.exprtools import Factors >>> from sympy.abc import x, y, z >>> a = Factors((x*y**2).as_powers_dict()) >>> b = Factors((x*y/z).as_powers_dict()) >>> a.quo(b) # same as a/b Factors({y: 1}) """
def rem(self, other): # Factors """Return denominator Factors of ``self / other``.
Examples ========
>>> from sympy.core.exprtools import Factors >>> from sympy.abc import x, y, z >>> a = Factors((x*y**2).as_powers_dict()) >>> b = Factors((x*y/z).as_powers_dict()) >>> a.rem(b) Factors({z: -1}) >>> a.rem(a) Factors({}) """ return self.div(other)[1]
def pow(self, other):  # Factors
    """Return self raised to a non-negative integer power.

    ``other`` may also be a Factors instance, in which case its
    expression form is used (converted to ``int`` when it is an Integer).
    A ValueError is raised for anything that is not a non-negative
    integer.

    Examples
    ========

    >>> from sympy.core.exprtools import Factors
    >>> from sympy.abc import x, y
    >>> a = Factors((x*y**2).as_powers_dict())
    >>> a**2
    Factors({x: 2, y: 4})

    """
    if isinstance(other, Factors):
        other = other.as_expr()
        if other.is_Integer:
            other = int(other)

    if not (isinstance(other, SYMPY_INTS) and other >= 0):
        raise ValueError("expected non-negative integer, got %s" % other)

    # a zero power leaves no factors at all
    scaled = {}
    if other:
        scaled = {base: exp*other for base, exp in self.factors.items()}
    return Factors(scaled)
def gcd(self, other): # Factors """Return Factors of ``gcd(self, other)``. The keys are the intersection of factors with the minimum exponent for each factor.
Examples ========
>>> from sympy.core.exprtools import Factors >>> from sympy.abc import x, y, z >>> a = Factors((x*y**2).as_powers_dict()) >>> b = Factors((x*y/z).as_powers_dict()) >>> a.gcd(b) Factors({x: 1, y: 1}) """ other = Factors(other) if other.is_zero: return Factors(self.factors)
def lcm(self, other): # Factors """Return Factors of ``lcm(self, other)`` which are the union of factors with the maximum exponent for each factor.
Examples ========
>>> from sympy.core.exprtools import Factors >>> from sympy.abc import x, y, z >>> a = Factors((x*y**2).as_powers_dict()) >>> b = Factors((x*y/z).as_powers_dict()) >>> a.lcm(b) Factors({x: 1, y: 2, z: -1}) """ other = Factors(other) if any(f.is_zero for f in (self, other)): return Factors(S.Zero)
exp = max(exp, factors[factor])
def __mul__(self, other): # Factors return self.mul(other)
def __divmod__(self, other): # Factors return self.div(other)
def __div__(self, other): # Factors return self.quo(other)
__truediv__ = __div__
def __mod__(self, other): # Factors return self.rem(other)
def __pow__(self, other): # Factors return self.pow(other)
def __eq__(self, other): # Factors if not isinstance(other, Factors): other = Factors(other) return self.factors == other.factors
def __ne__(self, other): # Factors return not self.__eq__(other)
class Term(object): """Efficient representation of ``coeff*(numer/denom)``. """
__slots__ = ['coeff', 'numer', 'denom']
def __init__(self, term, numer=None, denom=None): # Term raise NonCommutativeExpression( 'commutative expression expected')
else:
else:
numer = Factors()
denom = Factors()
def __hash__(self): # Term return hash((self.coeff, self.numer, self.denom))
def __repr__(self): # Term return "Term(%s, %s, %s)" % (self.coeff, self.numer, self.denom)
def as_expr(self): # Term
def mul(self, other): # Term
def inv(self): # Term
def quo(self, other): # Term
def pow(self, other):  # Term
    """Return this term raised to the power ``other``.

    A negative ``other`` is handled by inverting the term and raising
    the inverse to ``-other``; otherwise the coefficient, numerator and
    denominator are each raised to ``other``.
    """
    if other < 0:
        return self.inv().pow(-other)

    coeff = self.coeff ** other
    numer = self.numer.pow(other)
    denom = self.denom.pow(other)
    return Term(coeff, numer, denom)
def gcd(self, other): # Term self.numer.gcd(other.numer), self.denom.gcd(other.denom))
def lcm(self, other):  # Term
    """Return the Term built from the component-wise ``lcm`` of the
    coefficients, numerators and denominators of ``self`` and ``other``."""
    coeff = self.coeff.lcm(other.coeff)
    numer = self.numer.lcm(other.numer)
    denom = self.denom.lcm(other.denom)
    return Term(coeff, numer, denom)
def __mul__(self, other): # Term if isinstance(other, Term): return self.mul(other) else: return NotImplemented
def __div__(self, other): # Term if isinstance(other, Term): return self.quo(other) else: return NotImplemented
__truediv__ = __div__
def __pow__(self, other):  # Term
    """Implement ``self ** other`` for exponents accepted by SYMPY_INTS.

    Any other exponent type yields ``NotImplemented``.
    """
    if not isinstance(other, SYMPY_INTS):
        return NotImplemented
    return self.pow(other)
def __eq__(self, other): # Term return (self.coeff == other.coeff and self.numer == other.numer and self.denom == other.denom)
def __ne__(self, other): # Term return not self.__eq__(other)
def _gcd_terms(terms, isprimitive=False, fraction=True): """Helper function for :func:`gcd_terms`.
If ``isprimitive`` is True then the call to primitive for an Add will be skipped. This is useful when the content has already been extracted.
If ``fraction`` is True then the expression will appear over a common denominator, the lcm of all term denominators. """
terms = Add.make_args(terms)
# there is some simplification that may happen if we leave this # here rather than duplicate it before the mapping of Term onto # the terms return S.Zero, S.Zero, S.One
cont = terms[0].coeff numer = terms[0].numer.as_expr() denom = terms[0].denom.as_expr()
else:
else:
def gcd_terms(terms, isprimitive=False, clear=True, fraction=True): """Compute the GCD of ``terms`` and put them together.
``terms`` can be an expression or a non-Basic sequence of expressions which will be handled as though they are terms from a sum.
If ``isprimitive`` is True the _gcd_terms will not run the primitive method on the terms.
``clear`` controls the removal of integers from the denominator of an Add expression. When True (default), all numerical denominators will be cleared; when False the denominators will be cleared only if all terms had numerical denominators other than 1.
``fraction``, when True (default), will put the expression over a common denominator.
Examples ========
>>> from sympy.core import gcd_terms >>> from sympy.abc import x, y
>>> gcd_terms((x + 1)**2*y + (x + 1)*y**2) y*(x + 1)*(x + y + 1) >>> gcd_terms(x/2 + 1) (x + 2)/2 >>> gcd_terms(x/2 + 1, clear=False) x/2 + 1 >>> gcd_terms(x/2 + y/2, clear=False) (x + y)/2 >>> gcd_terms(x/2 + 1/x) (x**2 + 2)/(2*x) >>> gcd_terms(x/2 + 1/x, fraction=False) (x + 2/x)/2 >>> gcd_terms(x/2 + 1/x, fraction=False, clear=False) x/2 + 1/x
>>> gcd_terms(x/2/y + 1/x/y) (x**2 + 2)/(2*x*y) >>> gcd_terms(x/2/y + 1/x/y, clear=False) (x**2/2 + 1)/(x*y) >>> gcd_terms(x/2/y + 1/x/y, clear=False, fraction=False) (x/2 + 1/x)/y
The ``clear`` flag was ignored in this case because the returned expression was a rational expression, not a simple sum.
See Also ======== factor_terms, sympy.polys.polytools.terms_gcd
""" """replace nc portions of each term with a unique Dummy symbols and return the replacements to restore them""" nc = Mul._from_args(nc) d = Dummy() reps.append((d, nc)) c.append(d) args[i] = Mul._from_args(c) else:
is_sequence(terms, include=set) and \ not isinstance(terms, Dict)
else: for a in _numer.args):
if not isinstance(terms, Basic): return terms
if terms.is_Atom: return terms
if terms.is_Mul: c, args = terms.as_coeff_mul() return _keep_coeff(c, Mul(*[gcd_terms(i, isprimitive, clear, fraction) for i in args]), clear=clear)
def handle(a): # don't treat internal args like terms of an Add if not isinstance(a, Expr): if isinstance(a, Basic): return a.func(*[handle(i) for i in a.args]) return type(a)([handle(i) for i in a]) return gcd_terms(a, isprimitive, clear, fraction)
if isinstance(terms, Dict): return Dict(*[(k, handle(v)) for k, v in terms.args]) return terms.func(*[handle(i) for i in terms.args])
def factor_terms(expr, radical=False, clear=False, fraction=False, sign=True): """Remove common factors from terms in all arguments without changing the underlying structure of the expr. No expansion or simplification (and no processing of non-commutatives) is performed.
If radical=True then a radical common to all terms will be factored out of any Add sub-expressions of the expr.
If clear=False (default) then coefficients will not be separated from a single Add if they can be distributed to leave one or more terms with integer coefficients.
If fraction=True (default is False) then a common denominator will be constructed for the expression.
If sign=True (default) then even if the only factor in common is a -1, it will be factored out of the expression.
Examples ========
>>> from sympy import factor_terms, Symbol >>> from sympy.abc import x, y >>> factor_terms(x + x*(2 + 4*y)**3) x*(8*(2*y + 1)**3 + 1) >>> A = Symbol('A', commutative=False) >>> factor_terms(x*A + x*A + x*y*A) x*(y*A + 2*A)
When ``clear`` is False, a rational will only be factored out of an Add expression if all terms of the Add have coefficients that are fractions:
>>> factor_terms(x/2 + 1, clear=False) x/2 + 1 >>> factor_terms(x/2 + 1, clear=True) (x + 2)/2
If a -1 is all that can be factored out, to *not* factor it out, the flag ``sign`` must be False:
>>> factor_terms(-x - y) -(x + y) >>> factor_terms(-x - y, sign=False) -x - y >>> factor_terms(-2*x - 2*y, sign=False) -2*(x + y)
See Also ======== gcd_terms, sympy.polys.polytools.terms_gcd
"""
return type(expr)([do(i) for i in expr])
is_iterable or not hasattr(expr, 'args_cnc'):
# get a common negative (if there) which gcd_terms does not remove # watch out for exp(-(x+2)) which gcd_terms will change to exp(-x-2) list_args[i] = Dummy() special[list_args[i]] = a # rebuild p not worrying about the order which gcd_terms will fix isprimitive=True, clear=clear, fraction=fraction).xreplace(special) *[do(a) for a in p.args])
def _mask_nc(eq, name=None): """ Return ``eq`` with non-commutative objects replaced with Dummy symbols. A dictionary that can be used to restore the original values is returned: if it is None, the expression is noncommutative and cannot be made commutative. The third value returned is a list of any non-commutative symbols that appear in the returned equation.
``name``, if given, is the name that will be used with numered Dummy variables that will replace the non-commutative objects and is mainly used for doctesting purposes.
Notes ===== All non-commutative objects other than Symbols are replaced with a non-commutative Symbol. Identical objects will be identified by identical symbols.
If there is only 1 non-commutative object in an expression it will be replaced with a commutative symbol. Otherwise, the non-commutative entities are retained and the calling routine should handle replacements in this case since some care must be taken to keep track of the ordering of symbols when they occur within Muls.
Examples ========
>>> from sympy.physics.secondquant import Commutator, NO, F, Fd >>> from sympy import symbols, Mul >>> from sympy.core.exprtools import _mask_nc >>> from sympy.abc import x, y >>> A, B, C = symbols('A,B,C', commutative=False)
One nc-symbol:
>>> _mask_nc(A**2 - x**2, 'd') (_d0**2 - x**2, {_d0: A}, [])
Multiple nc-symbols:
>>> _mask_nc(A**2 - B**2, 'd') (A**2 - B**2, None, [A, B])
An nc-object with nc-symbols but no others outside of it:
>>> _mask_nc(1 + x*Commutator(A, B), 'd') (_d0*x + 1, {_d0: Commutator(A, B)}, []) >>> _mask_nc(NO(Fd(x)*F(y)), 'd') (_d0, {_d0: NO(CreateFermion(x)*AnnihilateFermion(y))}, [])
Multiple nc-objects:
>>> eq = x*Commutator(A, B) + x*Commutator(A, C)*Commutator(A, B) >>> _mask_nc(eq, 'd') (x*_d0 + x*_d1*_d0, {_d0: Commutator(A, B), _d1: Commutator(A, C)}, [_d0, _d1])
Multiple nc-objects and nc-symbols:
>>> eq = A*Commutator(A, B) + B*Commutator(A, C) >>> _mask_nc(eq, 'd') (A*_d0 + B*_d1, {_d0: Commutator(A, B), _d1: Commutator(A, C)}, [_d0, _d1, A, B])
If there is an object that:
- doesn't contain nc-symbols - but has arguments which derive from Basic, not Expr - and doesn't define an _eval_is_commutative routine
then it will give False (or None?) for the is_commutative test. Such objects are also removed by this routine:
>>> from sympy import Basic >>> eq = (1 + Mul(Basic(), Basic(), evaluate=False)) >>> eq.is_commutative False >>> _mask_nc(eq, 'd') (_d0**2 + 1, {_d0: Basic()}, [])
""" name = name or 'mask' # Make Dummy() append sequential numbers to the name
def numbered_names(): i = 0 while True: yield name + str(i) i += 1
names = numbered_names()
def Dummy(*args, **kwargs): from sympy import Dummy return Dummy(next(names), *args, **kwargs)
expr = eq if expr.is_commutative: return eq, {}, []
# identify nc-objects; symbols and other rep = [] nc_obj = set() nc_syms = set() pot = preorder_traversal(expr, keys=default_sort_key) for i, a in enumerate(pot): if any(a == r[0] for r in rep): pot.skip() elif not a.is_commutative: if a.is_Symbol: nc_syms.add(a) elif not (a.is_Add or a.is_Mul or a.is_Pow): if all(s.is_commutative for s in a.free_symbols): rep.append((a, Dummy())) else: nc_obj.add(a) pot.skip()
# If there is only one nc symbol or object, it can be factored regularly # but polys is going to complain, so replace it with a Dummy. if len(nc_obj) == 1 and not nc_syms: rep.append((nc_obj.pop(), Dummy())) elif len(nc_syms) == 1 and not nc_obj: rep.append((nc_syms.pop(), Dummy()))
# Any remaining nc-objects will be replaced with an nc-Dummy and # identified as an nc-Symbol to watch out for nc_obj = sorted(nc_obj, key=default_sort_key) for n in nc_obj: nc = Dummy(commutative=False) rep.append((n, nc)) nc_syms.add(nc) expr = expr.subs(rep)
nc_syms = list(nc_syms) nc_syms.sort(key=default_sort_key) return expr, {v: k for k, v in rep} or None, nc_syms
def factor_nc(expr):
    """Return the factored form of ``expr`` while handling non-commutative
    expressions.

    ``expr`` is sympified first.  Anything that is not an Expr, or that has
    no args, is returned unchanged; an expression that is not an Add is
    rebuilt from its recursively factored args.  For an Add, the
    non-commutative parts are masked with ``_mask_nc``; if that yields a
    replacement dictionary the masked expression is passed to ``factor`` and
    the replacements are restored.  Otherwise a commutative gcd and any
    common non-commutative prefix/suffix are pulled out of the terms before
    the remainder is factored.

    Examples
    ========

    >>> from sympy.core.exprtools import factor_nc
    >>> from sympy import Symbol
    >>> from sympy.abc import x
    >>> A = Symbol('A', commutative=False)
    >>> B = Symbol('B', commutative=False)
    >>> factor_nc((x**2 + 2*A*x + A**2).expand())
    (x + A)**2
    >>> factor_nc(((x + A)*(x + B)).expand())
    (x + A)*(x + B)
    """
    from sympy.simplify.simplify import powsimp
    from sympy.polys import gcd, factor

    def _pemexpand(expr):
        "Expand with the minimal set of hints necessary to check the result."
        return expr.expand(deep=True, mul=True, power_exp=True,
            power_base=False, basic=False, multinomial=True, log=False)

    expr = sympify(expr)
    # non-Expr objects and objects without args are returned as given
    if not isinstance(expr, Expr) or not expr.args:
        return expr
    # only an Add is factored here; otherwise recurse into the args
    if not expr.is_Add:
        return expr.func(*[factor_nc(a) for a in expr.args])

    expr, rep, nc_symbols = _mask_nc(expr)
    if rep:
        # _mask_nc supplied replacements: factor the masked expression
        # and substitute the original objects back
        return factor(expr).subs(rep)
    else:
        # one args_cnc() result per term; index 0 and index 1 of each entry
        # are handled below as the commutative and nc factor lists
        args = [a.args_cnc() for a in Add.make_args(expr)]
        # c, g: coefficient and remaining part of the commutative gcd;
        # l, r: common nc factor extracted on the left and on the right
        c = g = l = r = S.One
        hit = False  # set once anything has been extracted from the terms
        # find any commutative gcd term
        for i, a in enumerate(args):
            if i == 0:
                c = Mul._from_args(a[0])
            elif a[0]:
                c = gcd(c, Mul._from_args(a[0]))
            else:
                c = S.One
        if c is not S.One:
            hit = True
            c, g = c.as_coeff_Mul()
            if g is not S.One:
                # divide g out of the commutative part of every term
                for i, (cc, _) in enumerate(args):
                    cc = list(Mul.make_args(Mul._from_args(list(cc))/g))
                    args[i][0] = cc
            # divide the coefficient c out of the first commutative factor
            for i, (cc, _) in enumerate(args):
                cc[0] = cc[0]/c
                args[i][0] = cc
        # find any noncommutative common prefix
        for i, a in enumerate(args):
            if i == 0:
                n = a[1][:]
            else:
                n = common_prefix(n, a[1])
            if not n:
                # is there a power that can be extracted?
                if not args[0][1]:
                    break
                b, e = args[0][1][0].as_base_exp()
                ok = False
                if e.is_Integer:
                    for t in args:
                        if not t[1]:
                            break
                        bt, et = t[1][0].as_base_exp()
                        if et.is_Integer and bt == b:
                            e = min(e, et)
                        else:
                            break
                    else:
                        # every term starts with an integer power of b:
                        # pull out b**e with the smallest exponent seen
                        ok = hit = True
                        l = b**e
                        il = b**-e
                        for i, a in enumerate(args):
                            args[i][1][0] = il*args[i][1][0]
                        break
                if not ok:
                    break
        else:
            # loop ended without a break: n is a prefix common to all terms
            hit = True
            lenn = len(n)
            l = Mul(*n)
            for i, a in enumerate(args):
                args[i][1] = args[i][1][lenn:]
        # find any noncommutative common suffix
        for i, a in enumerate(args):
            if i == 0:
                n = a[1][:]
            else:
                n = common_suffix(n, a[1])
            if not n:
                # is there a power that can be extracted?
                if not args[0][1]:
                    break
                b, e = args[0][1][-1].as_base_exp()
                ok = False
                if e.is_Integer:
                    for t in args:
                        if not t[1]:
                            break
                        bt, et = t[1][-1].as_base_exp()
                        if et.is_Integer and bt == b:
                            e = min(e, et)
                        else:
                            break
                    else:
                        # every term ends with an integer power of b:
                        # pull out b**e with the smallest exponent seen
                        ok = hit = True
                        r = b**e
                        il = b**-e
                        for i, a in enumerate(args):
                            args[i][1][-1] = args[i][1][-1]*il
                        break
                if not ok:
                    break
        else:
            # loop ended without a break: n is a suffix common to all terms
            hit = True
            lenn = len(n)
            r = Mul(*n)
            for i, a in enumerate(args):
                args[i][1] = a[1][:len(a[1]) - lenn]
        if hit:
            # rebuild what is left of the terms after the extractions
            mid = Add(*[Mul(*cc)*Mul(*nc) for cc, nc in args])
        else:
            mid = expr

        # sort the symbols so the Dummys would appear in the same
        # order as the original symbols, otherwise you may introduce
        # a factor of -1, e.g. A**2 - B**2) -- {A:y, B:x} --> y**2 - x**2
        # and the former factors into two terms, (A - B)*(A + B) while the
        # latter factors into 3 terms, (-1)*(x - y)*(x + y)
        rep1 = [(n, Dummy()) for n in sorted(nc_symbols, key=default_sort_key)]
        unrep1 = [(v, k) for k, v in rep1]
        unrep1.reverse()
        new_mid, r2, _ = _mask_nc(mid.subs(rep1))
        new_mid = powsimp(factor(new_mid))

        # undo the second masking, then the Dummy-for-symbol replacement
        new_mid = new_mid.subs(r2).subs(unrep1)

        if new_mid.is_Pow:
            return _keep_coeff(c, g*l*new_mid*r)

        if new_mid.is_Mul:
            # XXX TODO there should be a way to inspect what order the terms
            # must be in and just select the plausible ordering without
            # checking permutations
            cfac = []
            ncfac = []
            for f in new_mid.args:
                if f.is_commutative:
                    cfac.append(f)
                else:
                    b, e = f.as_base_exp()
                    if e.is_Integer:
                        # list the base once per unit of the exponent
                        ncfac.extend([b]*e)
                    else:
                        ncfac.append(f)
            pre_mid = g*Mul(*cfac)*l
            target = _pemexpand(expr/c)
            # return the first ordering of the nc factors whose expansion
            # matches the expansion of the original expression
            for s in variations(ncfac, len(ncfac)):
                ok = pre_mid*Mul(*s)*r
                if _pemexpand(ok) == target:
                    return _keep_coeff(c, ok)

        # mid was an Add that didn't factor successfully
        return _keep_coeff(c, g*l*mid*r)