# Hugging Face file-viewer header captured with this file (kept as comments so the module parses):
# ahnobari | init | 133d58c | 6.15 kB
import os
import uuid
import argparse
# Command-line options as (flag, type, default, help text).
_CLI_OPTIONS = (
    ("--port", int, 1239, "Port number for the local server"),
    ("--cuda_device", str, '0', "Cuda devices to use. Default is 0"),
    ("--static_folder", str, 'static', "Folder to store static files"),
)

argparser = argparse.ArgumentParser()
for _flag, _kind, _default, _help in _CLI_OPTIONS:
    argparser.add_argument(_flag, type=_kind, default=_default, help=_help)
args = argparser.parse_args()

# Hide every CUDA device from this process; this is set before torch is imported
# because CUDA reads the variable when it initialises.
# NOTE(review): --cuda_device is parsed but never applied; CUDA is disabled
# unconditionally (presumably because only the CPU solver is used) — confirm intent.
# NOTE(review): --port is likewise parsed but not passed to block.launch().
os.environ["CUDA_VISIBLE_DEVICES"] = ''
import gradio as gr
from pathlib import Path
import numpy as np
import torch
from LInK.CAD import create_3d_html
import numpy as np
from LInK.CurveUtils import uniformize
import torch
import matplotlib.pyplot as plt
from LInK.Solver import solve_rev_vectorized_batch_CPU
# Inference only: turn off gradient tracking globally.
torch.set_grad_enabled(False)


def _center_and_extreme_angle(curves):
    """Center curves on their mean point and measure the angle of their farthest point.

    Args:
        curves: Float tensor indexed as ``curves[curve, point, coord]`` whose
            last axis holds point coordinates, x at index 0 and y at index 1
            (assumed shape ``(n_curves, n_points, 2)`` — TODO confirm).

    Returns:
        Tuple ``(centered, angles)``. ``centered`` is ``curves`` with each
        curve's mean point subtracted. ``angles`` is a NumPy array holding,
        per curve, ``atan2(y, x)`` of the centered point with the largest
        squared distance from the origin.
    """
    centered = curves - curves.mean(1).unsqueeze(1)
    farthest = torch.square(centered).sum(-1).argmax(dim=1)
    rows = torch.arange(centered.shape[0])
    angles = torch.atan2(centered[rows, farthest, 1], centered[rows, farthest, 0]).numpy()
    return centered, angles


# Precomputed per-letter results, target letter curves and per-letter z values.
# NOTE(review): allow_pickle=True unpickles these files; they come from fixed
# local paths, so only ship trusted files here.
results = np.load('alpha_res.npy', allow_pickle=True)
alphabet_test = np.load('alphabet.npy')
zs_ = np.load('alpha_z.npy', allow_pickle=True)

# Farthest-point angle of each target letter curve.
curves__, theta = _center_and_extreme_angle(torch.tensor(alphabet_test).float())

# Farthest-point angle of each traced curve (last entry of every result) after
# uniformize(..., 200) (presumably a resampling to 200 points).
# The arrays are stacked first: building a tensor from a list of ndarrays is slow.
traced = np.stack([res[-1] for res in results])
curves_, theta2 = _center_and_extreme_angle(uniformize(torch.tensor(traced).float(), 200))
# Manual corrections (radians) added to the automatic alignment angle of a few
# letters. Under alphabet_dict, index 21 is 'V', 7 is 'H' and 4 is 'E'.
_ALPHA_CORRECTIONS = {21: -np.pi / 2.5, 7: -np.pi / 3, 4: np.pi / 36}

# Per-letter rotation angle plus the width, height and bounding-box center of the
# rotated traced curve; create_mech uses these to rotate, scale and place letters.
alphas = []
letter_heights = []
letter_centers = []
letter_widths = []
for idx, res in enumerate(results):
    # Align the farthest-point angle of the traced curve (theta2) with that of
    # the target letter (theta), then apply any manual correction.
    alpha = theta[idx] - theta2[idx]
    if idx in _ALPHA_CORRECTIONS:
        alpha += _ALPHA_CORRECTIONS[idx]
    alphas.append(alpha)

    cos_a, sin_a = np.cos(alpha), np.sin(alpha)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]]).squeeze()
    rotated = (rotation @ res[-1].T).T

    # Axis-aligned bounding box of the rotated curve (column 0 = x, column 1 = y).
    lo = rotated.min(axis=0)
    hi = rotated.max(axis=0)
    letter_widths.append(hi[0] - lo[0])
    letter_heights.append(hi[1] - lo[1])
    letter_centers.append([(hi[0] + lo[0]) / 2, (hi[1] + lo[1]) / 2])

alphas = np.array(alphas)
letter_heights = np.array(letter_heights)
letter_centers = np.array(letter_centers)
letter_widths = np.array(letter_widths)
# Map each supported letter to its index in the precomputed data
# (results / alphas / zs_). The order is alphabetical except that Q precedes P,
# i.e. 'Q' -> 15 and 'P' -> 16.
# NOTE(review): the P/Q swap presumably mirrors the order of the saved data files; confirm.
alphabet_dict = {letter: index for index, letter in enumerate('ABCDEFGHIJKLMNOQPRSTUVWXYZ')}
def create_mech(target_text):
    """Build one combined linkage that traces ``target_text`` and return a 3D view of it.

    For every letter ``l`` the stored mechanism (``results[l][0]``) is rotated
    by ``alphas[l]``. It is scaled so the letter is ``target_height`` tall and
    shifted so letters sit left to right, ``spacing`` apart. The mechanisms are
    merged into one block-diagonal adjacency matrix. Each one is simulated with
    ``solve_rev_vectorized_batch_CPU`` over 100 evenly spaced input angles
    between its start and end angle. The result is rendered by
    ``create_3d_html`` into a new, uniquely named HTML file in the static folder.

    Args:
        target_text: User text. All whitespace is removed and the rest is
            upper-cased; every remaining character must be a key of
            ``alphabet_dict``.

    Returns:
        A ``gr.HTML`` component containing an iframe that points at the
        generated file.

    Raises:
        gr.Error: If no characters remain after removing whitespace, or if the
            text contains characters with no precomputed mechanism.
    """
    target_text = ''.join(target_text.split()).upper()
    if not target_text:
        raise gr.Error('Please enter at least one letter (A-Z).')
    unsupported = sorted({ch for ch in target_text if ch not in alphabet_dict})
    if unsupported:
        raise gr.Error(f"Unsupported character(s): {' '.join(unsupported)}. Only the letters A-Z are supported.")

    target_height = 1.0
    spacing = 0.2
    letters = [alphabet_dict[ch] for ch in target_text]

    # --- Layout: rotate, scale and translate each letter's mechanism. ---
    mechs = []
    total_size = 0
    center_x = 0.0
    prev_half_width = None
    for l in letters:
        A, x0, node_type, start_theta, end_theta, _ = results[l][0]
        alpha = alphas[l]
        R = np.array([[np.cos(alpha), -np.sin(alpha)], [np.sin(alpha), np.cos(alpha)]]).squeeze()
        s = target_height / letter_heights[l]
        half_width = letter_widths[l] / 2 * s
        # Space the letter centers so neighbouring scaled bounding boxes are `spacing` apart.
        if prev_half_width is None:
            center_x = half_width
        else:
            center_x = center_x + prev_half_width + half_width + spacing
        prev_half_width = half_width
        trans = [center_x, 0]
        # Rotate the nodes, move the bounding-box center to the origin, scale,
        # then translate. The input-angle range is shifted by the same rotation.
        x0_placed = s * ((R @ x0.T).T - letter_centers[l]) + trans
        mechs.append((A, x0_placed, node_type, start_theta + alpha, end_theta + alpha))
        total_size += A.shape[0]

    # --- Merge all mechanisms into one block-diagonal system and simulate. ---
    A_all = np.zeros((total_size, total_size))
    x0_all = np.zeros((total_size, 2))
    node_type_all = np.zeros((total_size, 1))
    sols = []
    highlights = []
    zs = []
    offset = 0
    for (A, x0, node_type, start_theta, end_theta), l in zip(mechs, letters):
        n = A.shape[0]
        span = slice(offset, offset + n)
        A_all[span, span] = A
        node_type_all[span] = node_type
        # Global index of this mechanism's last node, passed to create_3d_html as a highlight.
        highlights.append(offset + n - 1)
        sol = solve_rev_vectorized_batch_CPU(A[None], x0[None], node_type[None], np.linspace(start_theta, end_theta, 100))[0]
        sols.append(sol.transpose(1, 0, 2))
        # Initial positions come from index 0 along the solver's step axis
        # (presumably the first of the 100 input angles).
        x0_all[span] = sol[:, 0, :]
        offset += n
        # Offset this letter's z values to start one unit above the previous letter's maximum.
        z = zs_[l]
        zs.append(z + zs[-1].max() + 1 if zs else z)

    sols = np.concatenate(sols, axis=1)
    zs = np.concatenate(zs)
    frames = sols.transpose(1, 0, 2)
    # Append the motion reversed along axis 1 so it plays forward and then back.
    frames = np.concatenate([frames, frames[:, ::-1, :]], 1)

    # --- Render into the folder Gradio was told to serve (--static_folder). ---
    # NOTE(review): generated files are never removed; they accumulate per request.
    static_dir = args.static_folder
    os.makedirs(static_dir, exist_ok=True)
    uuid_ = str(uuid.uuid4())
    create_3d_html(A_all, x0_all, node_type_all, zs, frames, template_path='./static/animation.html', save_path=os.path.join(static_dir, f'{uuid_}.html'), highlights=highlights)
    return gr.HTML(f'<iframe width="100%" height="800px" src="file={static_dir}/{uuid_}.html"></iframe>', label="3D Plot", elem_classes="plot3d")
# Let Gradio serve files from the static folder (the placeholder page and the
# generated animation pages referenced via "file=static/...").
gr.set_static_paths(paths=[Path(f'./{args.static_folder}')])
# UI: a text box and a button; clicking the button calls create_mech with the
# text and replaces the HTML panel with the component it returns.
# NOTE(review): the original indentation was lost; the nesting below (plot in
# the same Column as the inputs) is a reconstruction — confirm the intended layout.
with gr.Blocks() as block:
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Enter a word (spaces will be ignored)", value='DECODE')
            btn = gr.Button(value="Create Mechanism", variant="primary")
            # Placeholder page shown until the first mechanism is generated.
            plot_3d = gr.HTML('<iframe width="100%" height="800px" src="file=static/filler.html"></iframe>',label="3D Plot",elem_classes="plot3d")
    event1 = btn.click(create_mech, inputs=[text], outputs=[plot_3d])
# NOTE(review): args.port is parsed but not passed here, so launch() falls back
# to Gradio's own port selection.
block.launch()