Inevitable

文字所表现出来的美丽、恐惧。。还有率真之类的东西,我觉得在其他事物中还是很难寻得的。

0%

初入数值分析,如何写好代码

写一些关于数学的代码,和开发程序是两种不同的感觉。接下来是一些我个人常用的 Tips。

Sum

求和符号频繁的出现在各种公式里面。比如 Composite Simpson Rule:

\(\int^a_b f(x)dx = \frac{h}{3}(f(a)+f(b)+4\sum^{n/2}_{i=1}f(a+(2i-1)h)+2\sum^{(n-2)/2}_{i=1}f(a+2ih))\)

我之前看到求和符号的第一反应是这儿又得用 for loop 了。例如 \(\sum^{n/2}_{i=1}f(a+(2i-1)h)\) 可能会用以下的代码来计算。

1
2
3
4
sum=0
for i in range(1,n//2+1):
sum = sum+f(a+(2*i-1)*h)
sum

这是一个有点冗长,不清楚的写法。

Map

我们可以用 np.summap 函数来实现。

1
2
import numpy as np
np.sum(np.array(list(map(lambda i:f(a+(2*i-1)*h),np.arange(1,n//2+1)))))

瞬间把之前的几行代码压缩成了一行。这样写的坏处是括号比较多,在没有括号高亮的情况下容易出现漏括号或者多括号的情况。

Vectorize

这个方法全靠 Numpy。

1
2
3
4
5
import numpy as np
i = np.arange(1,n//2+1)
sum1 = lambda i: f(a+(2*i-1)*h)
vfunc = np.vectorize(sum1)
np.sum(vfunc(i))

这个思路感觉和 map 的思路类似,但是优势是创建了一个可以复用的 Vectorize Function,可以接受数组的输入。

Another For Loop

这个其实和 for loop 没什么区别,知识短了一点。

1
2
import numpy as np
np.sum([f(a+(2*i-1)*h) for i in np.arange(1,n//2+1) ])

这个方法会比较灵活。在面对多个参数的时候会比较好用。比如有一个函数 f(a,b,c,d), 只有 c 这个参数需要变化,a,b,d 都是不要变化的。我们可以写这样写:

1
2
import numpy as np
[f(a,b,c,d) for c in np.arange(n)]

矩阵点乘

这里我们用一个简单一点的例子。我们需要计算\(\sum_{i=0}^{n}a_ib_i\)。这个的本质其实是 a 和 b 两个矩阵的点乘。

1
2
3
a=[...]
b=[...]
np.dot(a,b)

快速生成一个矩阵

这个很简单。例如生成一个 5*5 矩阵。

1
2
import numpy as np
np.zeros((5, 5))

Table

Tabulate

1
2
3
4
5
print(tabulate([["Name","Age"],["Alice",24],["Bob",19]],headers="firstrow"))
# Name Age
# ------ -----
# Alice 24
# Bob 19
1
2
3
4
5
6
7
print(tabulate({"Name": ["Alice", "Bob"],
"Age": [24, 19]}, headers="keys"))
# Age Name
# ----- ------
# 24 Alice
# 19 Bob

Plotly

1
2
3
4
5
6
import plotly.graph_objects as go

fig = go.Figure(data=[go.Table(header=dict(values=['A Scores', 'B Scores']),
cells=dict(values=[[100, 90, 80, 90], [95, 85, 75, 95]]))
])
fig.show()

Plot

我每次都记不住怎么画图。

matplotlib

1
2
3
4
5
6
7
8
9
10
11
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(9,6))
ax.loglog(hs,composite_trapezoid_rule_error_func1,'bo-',label='CTR ERROR',lw=2)
ax.loglog(hs,np.power(hs,2),'ro-',label='err(h) = h^2',lw=2)
ax.set_title("Error in CTR approximation",fontsize=22)
ax.legend(fontsize=15)
ax.set_xlabel('h',fontsize=22)
ax.set_ylabel('err(h)',fontsize=22)
ax.xaxis.set_tick_params(labelsize=15)
ax.yaxis.set_tick_params(labelsize=15)

伪代码翻译

数组的索引从 0 开始,但是很多伪代码是从 1 开始。为了方便翻译,我们可以在所有数组的开头插入一个 0,这样数组的有效数据便从索引 1 开始。这样伪代码和实际代码之间的索引便不会错了。以下为例子。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
def sor(a,b,XO,omega,TOL,N):
n=len(a)
x=np.zeros(n+1)
a = np.insert(a,0,0,axis=1)
a = np.insert(a,0,0,axis=0)
b = np.insert(b,0,0,axis=0)
XO = [0.0]+XO
# step 1
k=1
# step 2
while k<=N:
# step 3
for i in np.arange(1,n+1):
sum1=sum([a[i][j]*x[j] for j in np.arange(1,i)])
sum2=sum([a[i][j]*XO[j] for j in np.arange(i+1,n+1)])
x[i]=(1-omega)*XO[i]+1/a[i][i]*(omega*(np.negative(sum1)-sum2+b[i]))
# step 4
if max(np.abs(b-np.dot(a,x)))<TOL:
print(f'Number of Iterations: {k}')
return x[1:]
# step 5
k = k+1
# step 6
for i in np.arange(1,n+1):
XO[i]=x[i]
print('Maximum number of iterations exceed')
return XO

assert np.allclose(sor([[4,3,0],[3,4,-1],[0,-1,4]],[24,30,-24],[1.0,1.0,1.0],1.25,1e-05,1000),[3,4,-5], rtol=1e-05, atol=1e-08, equal_nan=False)

欢迎关注我的其它发布渠道