有用的python package

argh--懒人版argparse

import argh

def do_the_thing(required_arg, optional_arg=1, other_optional_arg=False):
    """
    Print every argument together with its Python type.

    Used as an argh command: argh derives the command line from this
    signature (``required_arg`` positional, the defaulted parameters as
    options), so the printed types show what argh converted each value to.
    """
    print((required_arg, type(required_arg)))
    print((optional_arg, type(optional_arg)))
    print((other_optional_arg, type(other_optional_arg)))


@argh.arg('--bool-arg-for-flag', '-b', help="Flip this flag for things")
@argh.arg('arg_with_choices', choices=['one', 'two', 'three'])
def do_the_other_thing(arg_with_choices, bool_arg_for_flag=False):
    """Print the chosen value and the value of the boolean flag.

    The ``argh.arg`` decorators restrict ``arg_with_choices`` to
    one/two/three and give ``--bool-arg-for-flag`` the short alias ``-b``
    plus help text.
    """
    print(arg_with_choices)
    print(bool_arg_for_flag)


if __name__ == '__main__':
    # Single-command form: argh.dispatch_command(do_the_thing)
    # Multi-command form: each listed function becomes a sub-command.
    argh.dispatch_commands([do_the_thing, do_the_other_thing])

msgpack--二进制版json

import msgpack
import json
import random


def msgpack_example_1():
    """Round-trip a dict with int keys through JSON and through msgpack.

    Writes ``json_file.json`` and ``msgpack_file.msgpack`` in the current
    working directory and reads both back. Then prints the type of the first
    key of each result: JSON turns the int keys into ``str``, while msgpack
    keeps them as ``int``.
    """
    example_dict = {i: random.random() for i in range(10000)}

    with open('json_file.json', 'w', encoding='utf-8') as f:
        json.dump(example_dict, f)
    with open('json_file.json', encoding='utf-8') as f:
        back_from_json = json.load(f)

    # Saving and loading
    with open('msgpack_file.msgpack', 'wb') as f:
        # f.write(msgpack.packb(example_dict, use_single_float=True))
        f.write(msgpack.packb(example_dict))

    with open('msgpack_file.msgpack', 'rb') as f:
        # msgpack >= 1.0 defaults to strict_map_key=True, which only accepts
        # str/bytes map keys and would raise ValueError on the int keys.
        back_from_msgpack = msgpack.unpackb(f.read(), strict_map_key=False)

    # Data integrity: compare the key type after each round trip.
    print(type(next(iter(back_from_json))))
    print(type(next(iter(back_from_msgpack))))


def msgpack_example_2():
    """Stream several msgpack objects into one file and read them back lazily.

    Each dict is packed and appended to ``streamed.msgpack`` in the current
    working directory. ``msgpack.Unpacker`` then iterates over the
    concatenated objects one at a time.
    """
    list_of_dicts = [{0: random.random()} for _ in range(100)]
    with open('streamed.msgpack', 'wb') as f:
        for d in list_of_dicts:
            f.write(msgpack.packb(d))

    # Read back iteratively. strict_map_key=False is required for the int
    # key 0 under msgpack >= 1.0, whose default only allows str/bytes keys.
    with open('streamed.msgpack', 'rb') as f:
        loaded_list_of_dicts = list(msgpack.Unpacker(f, strict_map_key=False))

    print(list_of_dicts[3][0], loaded_list_of_dicts[3][0])

if __name__ == '__main__':
    # msgpack_example_1()
    msgpack_example_2()

redis_cache--使用redis缓存函数

# sudo apt install redis-server
# sudo systemctl enable redis-server.service
# sudo systemctl start redis-server.service

# pip/pip3 install git+https://github.com/YashSinha1996/redis-simple-cache.git

import time
from redis_cache import cache_it, cache_it_json


@cache_it(limit=1000, expire=5)
def function_that_takes_a_long_time(i):
    """Return ``i ** 2``; stands in for an expensive computation.

    The print only runs when the body actually executes, so it shows which
    calls were not served from the ``cache_it`` cache (limit=1000,
    expire=5 -- presumably entries and seconds; confirm in redis_cache docs).
    """
    print(f"function was called with input {i}")
    return i**2


if __name__ == '__main__':
    # Same argument on every call: presumably only the first call (or one
    # made after the cache entry expires) prints from inside the function;
    # the rest should be cache hits. Requires a running Redis server.
    for i in range(10):
        print(i, function_that_takes_a_long_time(2))

schedule--定时运行函数

import time
import schedule

def test_function():
    """Print a line stamped with the current Unix time in seconds."""
    print(f'test called at {time.time()}')

def test_function_2():
    """Print a second, distinguishable line stamped with the current Unix time."""
    print(f'test 2 called at {time.time()}')

if __name__ == '__main__':
    schedule.every(1).seconds.do(test_function)
    schedule.every(3).seconds.do(test_function_2)
    # schedule.every(1).days.do(daily_task)
    # schedule.every().thursday.at("10:00").do(day_time_task)

    while True:
        schedule.run_pending()
        # Pause between checks; without it this loop busy-spins a CPU core.
        time.sleep(1)

tqdm--进度条显示

from tqdm import tqdm, trange
import random
import time


def tqdm_example_1():
    """Wrap a plain iterable with ``tqdm`` to display a progress bar."""
    for _ in tqdm(range(10)):
        time.sleep(0.2)


def tqdm_example_2():
    """Nested progress bars: an outer bar plus a fresh inner bar per outer step."""
    for _outer in trange(10, desc="outer_loop"):
        # leave=False clears each finished inner bar, so the terminal does not
        # accumulate one stale "inner_loop" line per outer iteration.
        for _inner in trange(10, desc="inner_loop", leave=False):
            time.sleep(0.01)


def tqdm_example_3(add_tot=False):
    """Advance a manually updated progress bar by random step sizes.

    Args:
        add_tot: if true, give the bar a description and a known total
            (``max_iter``); otherwise the bar only shows a running count.
    """
    max_iter = 100
    tot = 0

    bar_kwargs = {"desc": "update example", "total": max_iter} if add_tot else {}
    # The context manager closes the bar (finalizing its output line) when
    # the loop ends or if it raises.
    with tqdm(**bar_kwargs) as bar:
        while tot < max_iter:
            # Clamp the last step so the count ends exactly at max_iter
            # instead of overshooting past the total.
            update_iter = min(random.randint(1, 5), max_iter - tot)
            bar.update(update_iter)
            tot += update_iter
            time.sleep(0.03)


def tqdm_example_4():
    """Change the bar's description on every iteration."""
    progress = trange(100)
    for i in progress:
        progress.set_description(f"on iter {i}")
        time.sleep(0.02)


if __name__ == "__main__":
    # Uncomment one of these to try the other examples.
    # tqdm_example_1()
    # tqdm_example_2()
    # tqdm_example_3()
    # tqdm_example_3(True)
    tqdm_example_4()

Numba--矩阵运算加速

import numpy as np
from numba import njit
from concurrent.futures import ThreadPoolExecutor


def test_func(x):
out=0
for i in range(100000000):
out += i
return out


def test_heavy_func(times):
arr = np.random.rand(10000, 10000)
return arr * arr


if __name__ == "__main__":
    import time

    jitted_func = njit(test_func)
    jitted_func_2 = njit(test_heavy_func)
    # Same kernel compiled with nogil=True: the compiled code releases the
    # GIL while it runs, so the pool threads can execute it concurrently.
    jitted_func_3 = njit(test_heavy_func, nogil=True)

    # Trigger JIT compilation before timing so the measurements cover
    # execution only, not the one-off compile cost of the first call.
    jitted_func_2(0)
    jitted_func_3(0)

    # start = time.time()
    # with ThreadPoolExecutor(4) as ex:
    #     ex.map(jitted_func, range(1000))
    # end = time.time()
    # print("[Python origin test] Used time: ", end - start)

    start = time.perf_counter()
    with ThreadPoolExecutor(4) as ex:
        # Drain the result iterator: Executor.map only re-raises a worker's
        # exception when its result is retrieved, so an unconsumed iterator
        # would hide failures and report a meaningless time.
        for _ in ex.map(jitted_func_2, range(100)):
            pass
    end = time.perf_counter()
    print("[Numpy origin test] Used time: ", end - start)

    start = time.perf_counter()
    with ThreadPoolExecutor(4) as ex:
        for _ in ex.map(jitted_func_3, range(100)):
            pass
    end = time.perf_counter()
    print("[Numpy no gil test] Used time: ", end - start)

注意:在numba中使用一个普通的python列表不是一个好主意,因为它将花费很长时间来验证类型。

使用ndarray,才是正确的方法!才能带来速度的提升。

另外,@vectorize 可以把处理单个元素的函数转换成可以直接接受 ndarray 输入的优化函数(NumPy ufunc)。未指定类型签名时,第一次调用需要先进行 JIT 编译,所以会慢一些。

from numba import vectorize


@vectorize(nopython=True)
def non_list_function(item):
    """Element-wise: return 2 for even items and 1 for odd ones.

    Declared without type signatures, so Numba compiles a specialization
    lazily on the first call for each input type.
    """
    return 2 if item % 2 == 0 else 1


# NOTE(review): `test_list` is not defined in this snippet; presumably it was
# built earlier. Following the advice above, it should be a NumPy array
# (e.g. np.asarray(test_list)) rather than a plain Python list -- confirm.
non_list_function(test_list)
弹簧阻尼系统计算实例
def friction_fn(v, vt):
    """Damper force for velocity ``v`` and threshold velocity ``vt``.

    Linear damping ``-3 * v`` when ``v > vt``; otherwise the
    constant-magnitude force ``-3 * vt * sign(v)``.

    NOTE(review): the test is ``v > vt`` rather than ``abs(v) > vt``, so large
    negative velocities take the constant branch (v=-5, vt=1 gives +3, not
    +15). Presumably part of the "funky" damper -- confirm.
    """
    if v > vt:
        return -v * 3
    return -vt * 3 * np.sign(v)


def simulate_spring_mass_funky_damper(x0, T=10, dt=0.0001, vt=1.0):
    """Simulate a mass-spring system whose acceleration is
    ``friction_fn(v, vt) - 100 * x``.

    Uses semi-implicit Euler steps: the velocity is updated first and the new
    velocity advances the position.

    Args:
        x0: initial displacement; must be non-zero because every position is
            divided by it.
        T: simulated duration.
        dt: time step.
        vt: threshold velocity passed to ``friction_fn``.

    Returns:
        ``(times, positions)`` with ``times = np.arange(0, T, dt)`` and
        ``positions[i]`` the displacement at ``times[i]`` divided by ``x0``
        (so ``positions[0] == 1``). Both are empty when ``T <= 0``.
    """
    times = np.arange(0, T, dt)
    positions = np.zeros_like(times)
    if times.size == 0:
        # Nothing to simulate; positions[0] below would raise IndexError.
        return times, positions

    v = 0.0
    x = x0
    positions[0] = x0 / x0

    # Index 0 holds the initial condition, so integration starts at 1.
    for ii in range(1, len(times)):
        a = friction_fn(v, vt) - 100 * x
        v = v + a * dt
        x = x + v * dt
        positions[ii] = x / x0
    return times, positions
# IPython/Jupyter magic (not plain Python): time one call of the pure-Python version.
%time _ = simulate_spring_mass_funky_damper(0.1)

单次运行约 280 ms;若让输入 x0 从 0 到 10000、每次增加 0.1(约 10 万次调用),则需要约 7.8 小时。

@njit
def friction_fn(v, vt):
    """Damper force for velocity ``v`` and threshold velocity ``vt`` (compiled).

    Linear damping ``-3 * v`` when ``v > vt``; otherwise the
    constant-magnitude force ``-3 * vt * sign(v)``. It must be compiled too,
    because nopython-mode callers can only call other compiled functions.
    """
    if v > vt:
        return -v * 3
    return -vt * 3 * np.sign(v)

@njit
def simulate_spring_mass_funky_damper(x0, T=10, dt=0.0001, vt=1.0):
    """Compiled version of the mass-spring simulation
    (acceleration ``friction_fn(v, vt) - 100 * x``, semi-implicit Euler).

    Args:
        x0: initial displacement; must be non-zero (positions are divided
            by it).
        T: simulated duration.
        dt: time step.
        vt: threshold velocity passed to ``friction_fn``.

    Returns:
        ``(times, positions)`` with ``times = np.arange(0, T, dt)`` and
        ``positions`` the displacements divided by ``x0``; both are empty
        when ``T <= 0``.
    """
    times = np.arange(0, T, dt)
    positions = np.zeros_like(times)
    if times.size == 0:
        # Nothing to simulate; positions[0] below would be out of bounds.
        return times, positions

    v = 0.0
    x = x0
    positions[0] = x0 / x0

    # Index 0 holds the initial condition, so integration starts at 1.
    for ii in range(1, len(times)):
        a = friction_fn(v, vt) - 100 * x
        v = v + a * dt
        x = x + v * dt
        positions[ii] = x / x0
    return times, positions


# The first call triggers JIT compilation; later calls reuse the machine code.
_ = simulate_spring_mass_funky_damper(0.1)

运行时间 1.99 ms,相比 280 ms 约加速 140 倍。

再加速:

# Multi-threaded run. A function compiled without nogil=True keeps holding
# the GIL while it executes, so these threads mostly take turns rather than
# running in parallel.
from concurrent.futures import ThreadPoolExecutor

# Skip x0 = 0: the simulation divides by x0 and would raise for it.
x0_values = np.arange(0, 1000, 0.1)[1:]
with ThreadPoolExecutor(8) as ex:
    # Consume the results so a worker exception is re-raised here; an
    # unconsumed Executor.map iterator silently drops it.
    for _ in ex.map(simulate_spring_mass_funky_damper, x0_values):
        pass

当输入x0为从0到1000,每次增加0.1,需要19.3s。

再加速:用 @njit(nogil=True) 让编译后的函数在运行时释放 GIL 锁,这样多个线程才能真正同时利用多核:

@njit(nogil=True)
def friction_fn(v, vt):
    """Damper force for velocity ``v`` and threshold velocity ``vt``
    (compiled with ``nogil=True``).

    Linear damping ``-3 * v`` when ``v > vt``; otherwise the
    constant-magnitude force ``-3 * vt * sign(v)``.
    """
    if v > vt:
        return -v * 3
    return -vt * 3 * np.sign(v)

@njit(nogil=True)
def simulate_spring_mass_funky_damper(x0, T=10, dt=0.0001, vt=1.0):
    """Compiled, GIL-releasing mass-spring simulation
    (acceleration ``friction_fn(v, vt) - 100 * x``, semi-implicit Euler).

    nogil=True lets several threads run this function at the same time.

    Args:
        x0: initial displacement; must be non-zero (positions are divided
            by it).
        T: simulated duration.
        dt: time step.
        vt: threshold velocity passed to ``friction_fn``.

    Returns:
        ``(times, positions)`` with ``times = np.arange(0, T, dt)`` and
        ``positions`` the displacements divided by ``x0``; both are empty
        when ``T <= 0``.
    """
    times = np.arange(0, T, dt)
    positions = np.zeros_like(times)
    if times.size == 0:
        # Nothing to simulate; positions[0] below would be out of bounds.
        return times, positions

    v = 0.0
    x = x0
    positions[0] = x0 / x0

    # Index 0 holds the initial condition, so integration starts at 1.
    for ii in range(1, len(times)):
        a = friction_fn(v, vt) - 100 * x
        v = v + a * dt
        x = x + v * dt
        positions[ii] = x / x0
    return times, positions


# Compile ahead of time with a warm-up call, so later (timed) calls do not
# pay the compilation cost.
_ = simulate_spring_mass_funky_damper(0.1)
from concurrent.futures import ThreadPoolExecutor

# Skip x0 = 0: the simulation divides by x0 and would raise for it.
x0_values = np.arange(0, 1000, 0.1)[1:]
with ThreadPoolExecutor(8) as ex:
    # Consume the results so a worker exception is re-raised here; an
    # unconsumed Executor.map iterator silently drops it.
    for _ in ex.map(simulate_spring_mass_funky_damper, x0_values):
        pass

当输入x0为从0到1000,每次增加0.1,需要1.83s。

也可以不手动创建线程池,而是使用 numba 自带的并行功能(parallel=True 配合 prange,由 numba 的线程层在多核上调度,并不是多进程):

from numba import prange


@njit(nogil=True, parallel=True)
def run_sims(end=1000, step=0.1):
    """Run the simulation for x0 = step, 2*step, ... (below ``end``) in parallel.

    Iteration 0 (x0 = 0) is skipped because the simulation divides by x0.
    The prange loop is split across Numba's parallel threads; results are
    discarded.

    Args:
        end: upper bound of the x0 sweep (default 1000).
        step: spacing between successive x0 values (default: the original
            hard-coded 0.1).
    """
    # Round rather than truncate: e.g. 0.3 / 0.1 == 2.9999999999999996,
    # which int() alone would turn into 2.
    n_steps = int(end / step + 0.5)
    for i in prange(n_steps):
        if i == 0:
            continue
        simulate_spring_mass_funky_damper(i * step)


run_sims()

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!