import numpy as np
import matplotlib.pyplot as plt
import sys
# Command-line arguments: CHAIN_NUM selects the input file
# "<CHAIN_NUM>_chain.txt"; INTERVAL is the step between successive values of N.
if len(sys.argv) < 3:
    sys.exit("usage: " + sys.argv[0] + " CHAIN_NUM INTERVAL")
chain_num = int(sys.argv[1])
interval = int(sys.argv[2])
if interval < 1:
    sys.exit("INTERVAL must be a positive integer")
# One chain per row, one draw per column, comma separated. ndmin=2 keeps a
# single-row file two-dimensional so data.shape[1] is still the draw count.
data = np.loadtxt(str(chain_num)+"_chain.txt", delimiter=',', ndmin=2)
# Number of chains, taken from the file itself so the statistics remain
# correct even if the row count differs from CHAIN_NUM.
M = data.shape[0]

# Largest usable N: shrink_factor(N) reads the first 2N draws of each chain.
maxN = data.shape[1] // 2
def shrink_factor(N, chains=None, verbose=True):
    """Return the Gelman-Rubin shrink factor (R-hat) using N draws per chain.

    The first ``2N`` draws of every chain are taken, the first half of them
    is discarded as burn-in, and the statistic is computed from draws
    ``N .. 2N-1``.

    Parameters
    ----------
    N : int
        Number of retained draws per chain.
    chains : array_like, optional
        2-D array with one chain per row and one draw per column. Defaults
        to the module-level ``data`` loaded by this script.
    verbose : bool, optional
        When true, print an in-place ``N/maxN`` progress indicator.

    Returns
    -------
    float
        ``sqrt(V_hat / W)``, where W is the mean within-chain variance and
        ``V_hat = (1 - 1/N) * W + B / N`` with B the between-chain variance.

    Raises
    ------
    ValueError
        If fewer than two chains are given, or a chain has fewer than
        ``2N`` draws.
    """
    if chains is None:
        chains = data
    chains = np.atleast_2d(np.asarray(chains, dtype=float))
    n_chains, n_draws = chains.shape
    if n_chains < 2:
        raise ValueError("the shrink factor needs at least two chains")
    if 2 * N > n_draws:
        raise ValueError(
            f"N={N} needs {2 * N} draws per chain, only {n_draws} available"
        )
    # Keep the second half of the first 2N draws; the first half is burn-in.
    subdata = chains[:, N:2 * N]
    # Within-chain variance W: mean of the unbiased (ddof=1) chain variances.
    W = np.mean(np.var(subdata, axis=1, ddof=1))
    # Between-chain variance B: N times the sample variance of chain means.
    theta = np.mean(subdata, axis=1)
    B = N * np.var(theta, ddof=1)
    vartheta = (1 - 1 / N) * W + B / N
    r = np.sqrt(vartheta / W)
    if verbose:
        print(str(N) + '/' + str(n_draws // 2) + '\r', end='')
    return r





# Evaluate the shrink factor at N = interval, 2*interval, ... up to and
# including maxN (the largest N the data supports), so that the final point
# uses every draw in the file.
Ns = list(range(interval, maxN + 1, interval))
results = [shrink_factor(N) for N in Ns]
print()  # terminate the '\r' progress line printed by shrink_factor
np.save("shrink_factors_"+str(chain_num)+"_chain.npy", results)
plt.plot(Ns, results)
plt.xlabel("N (retained draws per chain)")
plt.ylabel("shrink factor")
plt.show()
