LInKAlphabetDemo / LInK /Visulization.py
ahnobari
init
460c05d
raw
history blame contribute delete
No virus
3.48 kB
from .Solver import solve_mechanism
import numpy as np
import matplotlib.pyplot as plt
from io import StringIO
import xml.etree.ElementTree as etree
from svgpath2mpl import parse_path
def draw_mechanism(A,x0,fixed_nodes=None,motor=[0,1],node_types=None, ax=None, highlight=None, solve=True, thetas = np.linspace(0,np.pi*2,200), def_alpha = 1.0, h_alfa =1.0, h_c = "#f15a24"):
if fixed_nodes is None and node_types is None:
raise ValueError("Either fixed_nodes or node_types should be provided")
if fixed_nodes is None:
fixed_nodes = np.where(node_types)[0]
def fetch_path():
# r = requests.get(svg_url)
root = etree.parse(StringIO('<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 620 338"><defs><style>.cls-1{fill:#1a1a1a;stroke:#1a1a1a;stroke-linecap:round;stroke-miterlimit:10;stroke-width:20px;}</style></defs><path class="cls-1" d="M45.5,358.5l70.71-70.71M46,287.5H644m-507.61,71,70.72-70.71M223,358.5l70.71-70.71m20.18,70.72,70.71-70.71m13.67,70.7,70.71-70.71m20.19,70.72,70.71-70.71m15.84,70.71,70.71-70.71M345,39.62A121.38,121.38,0,1,1,223.62,161,121.38,121.38,0,0,1,345,39.62Z" transform="translate(-35.5 -29.62)"/></svg>')).getroot()
view_box = root.attrib.get('viewBox')
if view_box is not None:
view_box = [int(x) for x in view_box.split()]
xlim = (view_box[0], view_box[0] + view_box[2])
ylim = (view_box[1] + view_box[3], view_box[1])
else:
xlim = (0, 500)
ylim = (500, 0)
path_elem = root.findall('.//{http://www.w3.org/2000/svg}path')[0]
return xlim, ylim, parse_path(path_elem.attrib['d'])
_,_,p = fetch_path()
p.vertices -= p.vertices.mean(axis=0)
p.vertices = (np.array([[np.cos(np.pi), -np.sin(np.pi)],[np.sin(np.pi), np.cos(np.pi)]])@p.vertices.T).T
if ax is None:
fig, ax = plt.subplots(figsize=(10,10))
A,x0,fixed_nodes,motor = np.array(A),np.array(x0),np.array(fixed_nodes),np.array(motor)
x = x0
N = A.shape[0]
for i in range(N):
if i in fixed_nodes:
if i == highlight:
ax.scatter(x[i,0],x[i,1],color=h_c,s=700,zorder=10,marker=p)
else:
ax.scatter(x[i,0],x[i,1],color="#1a1a1a",s=700,zorder=10,marker=p)
else:
if i == highlight:
ax.scatter(x[i,0],x[i,1],color=h_c,s=100,zorder=10,facecolors=h_c,alpha=0.7)
else:
ax.scatter(x[i,0],x[i,1],color="#1a1a1a",s=100,zorder=10,facecolors='#ffffff',alpha=0.7)
for j in range(i+1,N):
if A[i,j]:
if (motor[0] == i and motor[1] == j) or(motor[0] == j and motor[1] == i):
ax.plot([x[i,0],x[j,0]],[x[i,1],x[j,1]],color="#ffc800",linewidth=4.5)
else:
ax.plot([x[i,0],x[j,0]],[x[i,1],x[j,1]],color="#1a1a1a",linewidth=4.5,alpha=0.6)
if solve:
sol,ord = solve_mechanism(A,x0,motor,fixed_nodes,thetas)
x = sol
for i in range(A.shape[0]):
if not ord[i] in fixed_nodes:
if ord[i] == highlight:
ax.plot(x[i,:,0],x[i,:,1],'-',color=h_c,linewidth=4.5,alpha=h_alfa)
else:
ax.plot(x[i,:,0],x[i,:,1],'--',color="#0078a7",linewidth=1.5, alpha=def_alpha)
ax.axis('equal')
ax.axis('off')
return ax