1
2 """
3 Functions for manipulating expressions.
4 """
5 import math
6 import re
7 from copy import copy
8 import park.util.safemath
9 from deps import order_dependencies
10
11
12
13
# A symbol starts with a letter and continues with letters, digits,
# underscores or dots, so a dotted path such as "a.b.x" is matched as a
# single token.  The whole token is captured as group 1.
_symbol_pattern = re.compile(r'([a-zA-Z][a-zA-Z_0-9.]*)')
15
def symbols(expr, symtab):
    """
    Return the set of symbol-table entries referenced by an expression.

    Every token of *expr* that matches the module-level symbol pattern and
    is a key of *symtab* contributes ``symtab[token]`` to the result.
    Tokens that are not in *symtab* (function names, constants, ...) are
    ignored.  Each entry is returned once even if it occurs several times,
    and the result is a set, so the elements are in no particular order.

    Note that the returned elements are the *values* stored in *symtab*,
    not the symbol names, so those values must be hashable.

    This is the first step in computing a dependency graph.
    """
    names = (match.group(0) for match in _symbol_pattern.finditer(expr))
    return set(symtab[name] for name in names if name in symtab)
27
def substitute(expr, mapping):
    """
    Replace all occurrences of symbol s with mapping[s] for s in mapping.

    *expr* is the expression string and *mapping* maps symbol names to
    their replacement text.  Only complete tokens (as matched by the
    module-level symbol pattern) are replaced; a token that is not a key
    of *mapping* is left untouched.  Returns the rewritten string.
    """
    def _replacement(match):
        # Group 1 spans the whole token, so returning it unchanged is a
        # no-op for symbols that have no substitution.
        name = match.group(1)
        return mapping[name] if name in mapping else name

    return _symbol_pattern.sub(_replacement, expr)
47
def find_dependencies(pars):
    """
    Returns a list of pair-wise dependencies from the parameter expressions.

    Each parameter must provide ``path``, ``expression`` and
    ``iscomputed()``.  Only parameters for which ``iscomputed()`` is true
    contribute pairs.

    For example, if p3 = p1+p2, then find_dependencies([p1,p2,p3]) will
    return the pairs (p3,p1) and (p3,p2); their relative order is not
    guaranteed because the symbols of an expression are collected in a
    set.  For base expressions without dependencies, such as p4 = 2*pi,
    this returns [(p4, None)].
    """
    # Look parameters up by the dotted path used inside expressions.
    symtab = dict((p.path, p) for p in pars)

    deps = []
    for par in pars:
        if not par.iscomputed():
            continue
        used = symbols(par.expression, symtab)
        if used:
            deps.extend((par, dep) for dep in used)
        else:
            # Keep the parameter in the graph even though its expression
            # refers to no other parameter.
            deps.append((par, None))
    return deps
69
def parameter_mapping(pairs):
    """
    Find the parameter substitution we need so that expressions can
    be evaluated without having to traverse a chain of
    model.layer.parameter.value

    *pairs* is a sequence of (parameter, dependency) tuples in which the
    dependency may be None.  Returns ``(symtab, mapping)``:

    * ``symtab`` maps generated names ``'P0'``, ``'P1'``, ... to the
      distinct objects appearing in *pairs* (a None dependency, if
      present, also receives a name);
    * ``mapping`` maps each parameter's ``path`` to the text
      ``'P<i>.value'`` that should replace it in an expression.

    An empty *pairs* yields two empty dictionaries.
    """
    if not pairs:
        # zip(*[]) produces nothing to unpack; there is nothing to map.
        return {}, {}

    left, right = zip(*pairs)
    # Freeze the iteration order so both tables use the same numbering.
    pars = list(set(left + right))
    symtab = dict(('P%d' % i, p) for i, p in enumerate(pars))
    mapping = dict((p.path, 'P%d.value' % i)
                   for i, p in enumerate(pars)
                   if p is not None)
    return symtab, mapping
84
def no_constraints():
    """
    This parameter set has no constraints between the parameters.

    Placeholder updater returned by build_eval when no parameter has an
    expression; calling it does nothing and returns None.
    """
    return None
90
def build_eval(pars, context=None):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Inputs:
        pars is a list of parameters
        context is a dictionary of additional symbols for the expressions
        (optional; omitted or None means no extra symbols)

    Output:
        updater function taking no arguments; if no parameter has an
        expression, the no-op no_constraints is returned instead

    Raises (as listed in the original documentation of this function):
        AssertionError - model, parameter or function is missing
        SyntaxError - improper expression syntax
        ValueError - expressions have circular dependencies

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  While we have
    removed the builtins and the ability to import modules, there may
    be other vectors for users to perform more than simple function
    evaluations.  Unauthenticated users should not be running this code.

    Parameter names are assumed to contain only _.a-zA-Z0-9#[]

    The list of parameters is probably something like::

        parset.setprefix()
        pars = parset.flatten()

    Note that math uses acos while numpy uses arccos.  To avoid confusion
    we allow both.

    Should try running the function to identify syntax errors before
    running it in a fit.

    Use help(fn) to see the code generated for the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.
    """
    # Sort the parameters in the order they need to be evaluated.
    deps = find_dependencies(pars)
    if not deps:
        return no_constraints
    par_table, par_mapping = parameter_mapping(deps)
    order = order_dependencies(deps)

    # Namespace in which the generated function is compiled and later
    # runs: the safe math context, the caller-supplied symbols, and the
    # P<i> aliases for the parameter objects.
    namespace = copy(park.util.safemath.context)
    if context:
        namespace.update(context)
    namespace.update(par_table)
    scope = {}

    # Original assignments (used as documentation of the generated
    # function) and the same assignments rewritten to use P<i>.value.
    exprs = [p.path + "=" + p.expression for p in order]
    code = [substitute(s, par_mapping) for s in exprs]

    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
""" % ("\n    ".join(exprs), "\n    ".join(code))

    # SECURITY: this executes text derived from user-supplied expressions.
    # The call form of exec with explicit namespaces is accepted by both
    # Python 2 and Python 3.
    exec(functiondef, namespace, scope)
    retfn = scope['eval_expressions']

    # Strip the entries exec added to the namespace, including the
    # builtins, so the expressions run with only the symbols set up above.
    for key in ('__doc__', '__name__', '__file__', '__builtins__'):
        namespace.pop(key, None)

    return retfn
171
def test():
    """
    Self-test for symbols, substitute, find_dependencies and build_eval.

    Raises AssertionError if any check fails.
    """
    import dis
    import inspect
    import math

    symtab = {'a.b.x': 1, 'a.c': 2, 'a.b': 3, 'b.x': 4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Only symbols present in the table are reported, each one once.
    assert symbols(expr, symtab) == set([1, 2, 3])

    # Whole tokens are substituted; 'a.b' must not be touched by 'a.b.x'.
    assert substitute(expr, {'a.b.x': 'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'

    # Minimal stand-in providing the attributes the functions above use.
    class Parameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression

        def iscomputed(self):
            return (self.expression != '')

        def __repr__(self):
            return self.path

    p1 = Parameter('G0.sigma', 5)
    p2 = Parameter('other', expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = Parameter('M1.G1', 6)
    p4 = Parameter('constant', expression='2*pi*35')

    # Simple chain.
    assert set(find_dependencies([p1, p2, p3])) == set([(p2, p1), (p2, p3)])
    # Constant expression.
    assert set(find_dependencies([p1, p4])) == set([(p4, None)])
    # No dependencies.
    assert set(find_dependencies([p1, p3])) == set([])

    fn = build_eval([p1, p2, p3])

    # Flip to True to inspect the generated function while debugging.
    if False:
        print(inspect.getdoc(fn))
        print(dis.dis(fn))

    # Run the generated function and check the computed value.
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected, "Value was %s, not %s" % (p2.value, expected)

    # Empty dependency set: the returned updater must still be callable.
    fn = build_eval([p1, p3])
    fn()

    # Expression without any parameter dependency.
    fn = build_eval([p4])
    fn()
    assert p4.value == 2*math.pi*35

    # Symbols supplied through the context: attribute and item lookup.
    class Table:
        Si = 2.09
        values = {'Si': 2.07}

    tbl = Table()
    p5 = Parameter('lookup', expression="tbl.Si")
    fn = build_eval([p1, p2, p3, p5], context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.09, "Value for %s was %s" % (p5.expression, p5.value)
    p5.expression = "tbl.values['Si']"
    fn = build_eval([p1, p2, p3, p5], context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.07, "Value for %s was %s" % (p5.expression, p5.value)

    # Each of these must fail, either while building or while running.
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 'import sys; print "p0wned"',
                 '__import__("sys").argv']:
        try:
            p6 = Parameter('broken', expression=expr)
            fn = build_eval([p6])
            fn()
        except Exception:
            pass
        else:
            # String exceptions are not valid; raise a real exception so
            # the failure message is actually reported.
            raise AssertionError("Failed to raise error for %s" % expr)
255
# Run the self-test when the module is executed as a script.
if __name__ == "__main__":
    test()
257