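# K-Sort-Arena / serve/update_skill.py
# Author: ksort · commit 02f8ed6 ("Update ssh") · 3.23 kB
# (This header was captured from a web file viewer; the "raw", "history",
#  "blame" and "No virus" labels were page controls, not file content.)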
import numpy as np
import json
from trueskill import TrueSkill
import paramiko
import io, os
import sys
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
# Shared TrueSkill environment (constructed with library defaults). It is used
# for every rating update and to rebuild Rating objects from the stored JSON.
trueskill_env: TrueSkill = TrueSkill()
# Appends '../' to the import path. The path is resolved against the current
# working directory, not this file's location. Presumably this is so that
# `model.models` below can be found.
# NOTE(review): this runs after `serve.constants` has already been imported;
# confirm the intended launch directory and whether this is still needed.
sys.path.append('../')
# Model-name registry. A model's position in this sequence is the index used
# into the stored ratings list and the comparison_counts matrix.
from model.models import IMAGE_GENERATION_MODELS
# SSH connection and its SFTP session. Both are None until
# create_ssh_skill_client() has been called.
ssh_skill_client: "paramiko.SSHClient | None" = None
sftp_skill_client: "paramiko.SFTPClient | None" = None
def create_ssh_skill_client(server, port, user, password):
    """Connect to the skill-storage host and open an SFTP session on it.

    Side effects: rebinds the module globals ``ssh_skill_client`` (the
    ``paramiko.SSHClient``) and ``sftp_skill_client`` (its SFTP session).
    ``load_json_via_sftp`` and ``save_json_via_sftp`` read from the latter.

    Args:
        server: hostname or address to connect to.
        port: SSH port.
        user: login user name.
        password: login password.
    """
    global ssh_skill_client, sftp_skill_client
    # Publish the client globally before connecting, so the global is set
    # even if connect() raises.
    ssh_skill_client = client = paramiko.SSHClient()
    client.load_system_host_keys()
    # Hosts missing from the system known_hosts file are accepted and added.
    # NOTE(review): this trusts any unseen host key (MITM risk). Consider
    # RejectPolicy together with a pinned known_hosts entry.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect(hostname=server, port=port, username=user, password=password)
    sftp_skill_client = client.open_sftp()
def ucb_score(trueskill_diff, t, n, exploration_weight=1.0):
    """Compute an upper-confidence-bound score.

    score = -trueskill_diff + exploration_weight * sqrt(2 * ln(t) / n)

    A larger TrueSkill difference lowers the score. A smaller ``n`` raises the
    exploration bonus. Small epsilons guard against ``log(0)`` and division
    by zero.

    Args:
        trueskill_diff: skill difference term (scalar or numpy array).
        t: total count for the log term. Presumably the global
            ``total_comparisons``; verify against the caller.
        n: count for the denominator (scalar or numpy array). Presumably a
            per-pair entry of ``comparison_counts``; verify against the caller.
        exploration_weight: multiplier on the exploration bonus. The default
            of 1.0 reproduces the previous fixed weighting.

    Returns:
        The UCB score, with the same (broadcast) shape as the inputs.
    """
    # For t < 1 the log is negative, and sqrt of a negative number is NaN
    # (e.g. on the very first round, when t == 0). Clamping at 0 makes the
    # bonus 0 there and leaves every t >= 1 result unchanged.
    log_t = np.maximum(np.log(t + 1e-5), 0.0)
    exploration_term = np.sqrt((2 * log_t) / (n + 1e-5))
    return -trueskill_diff + exploration_weight * exploration_term
def update_trueskill(ratings, ranks):
    """Run one TrueSkill match update with the shared environment.

    Args:
        ratings: rating groups in the form ``TrueSkill.rate`` accepts. In this
            module each group is a one-element list.
        ranks: one rank per group; a lower rank means a better placement.

    Returns:
        The updated rating groups, in the same order as ``ratings``.
    """
    return trueskill_env.rate(ratings, ranks=ranks)
def serialize_rating(rating):
    """Convert a rating object into a JSON-friendly ``{'mu', 'sigma'}`` dict."""
    mu, sigma = rating.mu, rating.sigma
    return dict(mu=mu, sigma=sigma)
def deserialize_rating(rating_dict):
    """Rebuild a ``Rating`` from a dict with ``'mu'`` and ``'sigma'`` keys.

    This is the inverse of ``serialize_rating``.
    """
    mu = rating_dict['mu']
    sigma = rating_dict['sigma']
    return trueskill_env.Rating(mu=mu, sigma=sigma)
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    """Serialize the skill state to JSON and write it to ``SSH_SKILL`` over SFTP.

    The JSON object has three keys:
    - ``ratings``: a list of ``{'mu', 'sigma'}`` dicts.
    - ``comparison_counts``: a nested list built from the numpy array.
    - ``total_comparisons``: passed through unchanged.

    ``sftp_skill_client`` must already have been opened by
    ``create_ssh_skill_client``.
    """
    payload = json.dumps({
        'ratings': list(map(serialize_rating, ratings)),
        'comparison_counts': comparison_counts.tolist(),
        'total_comparisons': total_comparisons,
    })
    # NOTE(review): the remote file is truncated and rewritten in place. An
    # interrupted write leaves a partial file that the next load cannot parse.
    with sftp_skill_client.open(SSH_SKILL, 'w') as remote_file:
        remote_file.write(payload)
def load_json_via_sftp():
    """Read the skill state JSON from ``SSH_SKILL`` over SFTP.

    Returns:
        A tuple ``(ratings, comparison_counts, total_comparisons)``:
        - ``ratings``: rebuilt ``Rating`` objects.
        - ``comparison_counts``: a numpy array.
        - ``total_comparisons``: the stored value, as-is.
    """
    with sftp_skill_client.open(SSH_SKILL, 'r') as remote_file:
        state = json.load(remote_file)
    return (
        [deserialize_rating(entry) for entry in state['ratings']],
        np.array(state['comparison_counts']),
        state['total_comparisons'],
    )
def update_skill(rank, model_names, k_group=4):
    """Apply one ranked vote to the stored TrueSkill ratings and persist them.

    Steps:
    1. Load the current state over SFTP.
    2. Run a two-player TrueSkill update for every pair of the voted models.
       Pairs are processed one after another, each seeing the ratings already
       updated by earlier pairs.
    3. Bump the symmetric pairwise comparison counts and the global total.
    4. Save the state back.

    Args:
        rank: rank value for each entry of ``model_names``, indexed in the same
            order. A lower value beats a higher one. Equal values leave both
            ratings unchanged but still count as a comparison.
        model_names: names of the compared models. Each must appear in
            ``IMAGE_GENERATION_MODELS``; its index there selects the model's
            slot in the stored ratings and counts.
        k_group: unused; kept so existing callers keep working.

    Raises:
        ValueError: from ``IMAGE_GENERATION_MODELS.index`` when a name is not
            registered (assuming the registry is a list or tuple).
    """
    import itertools

    # Resolve names before touching remote state, so an unknown name fails
    # without an SFTP round-trip.
    group = [IMAGE_GENERATION_MODELS.index(name) for name in model_names]

    ratings, comparison_counts, total_comparisons = load_json_via_sftp()
    print(group)

    # combinations() yields (a, b) with a < b, in the same order as a nested
    # i / j > i loop.
    for a, b in itertools.combinations(range(len(group)), 2):
        idx_a, idx_b = group[a], group[b]
        if rank[a] < rank[b]:
            ranks = [0, 1]
        elif rank[a] > rank[b]:
            ranks = [1, 0]
        else:
            ranks = None  # Equal ranks: ratings are left unchanged.
        if ranks is not None:
            updated = update_trueskill([[ratings[idx_a]], [ratings[idx_b]]], ranks)
            ratings[idx_a], ratings[idx_b] = updated[0][0], updated[1][0]
        comparison_counts[idx_a, idx_b] += 1
        comparison_counts[idx_b, idx_a] += 1
        total_comparisons += 1

    save_json_via_sftp(ratings, comparison_counts, total_comparisons)