Skip to content
Snippets Groups Projects
Commit e8b8e146 authored by Bastien Batardière's avatar Bastien Batardière
Browse files

write the right portions of file

parent 3e4aac90
No related branches found
No related tags found
No related merge requests found
...@@ -411,37 +411,74 @@ def plot_ellipse(mean_x, mean_y, cov, ax): ...@@ -411,37 +411,74 @@ def plot_ellipse(mean_x, mean_y, cov, ax):
return pearson return pearson
def get_components_simulation(dim, rank):
    """Build a reproducible block-structured components matrix with noise.

    Column ``j`` is set to 1 on rows ``[j * block_size, (j + 1) * block_size)``
    with ``block_size = dim // rank``; rows past ``rank * block_size`` (when
    ``dim`` is not a multiple of ``rank``) only receive noise. Standard normal
    noise divided by 8 is then added to every entry.

    The noise is drawn with the global seed fixed to 0, and the CPU RNG state
    that was active on entry is restored afterwards, so repeated calls return
    the same matrix.

    Parameters
    ----------
    dim : int
        Number of rows of the returned matrix.
    rank : int
        Number of columns of the returned matrix. ``rank == 0`` raises
        ``ZeroDivisionError`` from the block-size computation.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(dim, rank)`` moved to ``DEVICE``.
    """
    block_size = dim // rank
    prev_state = torch.random.get_rng_state()
    torch.random.manual_seed(0)
    try:
        components = torch.zeros(dim, rank)
        for column_number in range(rank):
            start = column_number * block_size
            components[start : start + block_size, column_number] = 1
        components += torch.randn(dim, rank) / 8
    finally:
        # Restore the caller's RNG state even if building the matrix fails,
        # so an error here cannot leave the global generator seeded at 0.
        torch.random.set_rng_state(prev_state)
    return components.to(DEVICE)
def get_simulation_offsets_cov_coef(n_samples, nb_cov, dim):
    """Draw reproducible offsets, covariates and coefficients for a simulation.

    All draws happen with the global seed fixed to 0; the CPU RNG state that
    was active on entry is restored before returning.

    Parameters
    ----------
    n_samples : int
        Number of rows of ``offsets`` and ``covariates``.
    nb_cov : int
        Number of rows of ``coef``. ``covariates`` has ``nb_cov - 1`` columns
        (presumably the remaining coefficient row is an intercept handled by
        the sampler -- confirm against the caller). When ``nb_cov < 2`` no
        covariates are generated.
    dim : int
        Number of columns of ``offsets`` and ``coef``.

    Returns
    -------
    offsets : torch.Tensor
        ``(n_samples, dim)`` float64 tensor with integer values in ``{0, 1}``.
    covariates : torch.Tensor or None
        ``(n_samples, nb_cov - 1)`` float64 tensor with integer values in
        ``{-1, 0, 1}``, or ``None`` when ``nb_cov < 2``.
    coef : torch.Tensor
        ``(nb_cov, dim)`` tensor of standard normal draws.
    """
    prev_state = torch.random.get_rng_state()
    torch.random.manual_seed(0)
    # NOTE(review): only the CPU generator state is saved and restored here;
    # if DEVICE is not the CPU, the generator of that device stays reseeded.
    # Confirm whether that matters for callers.
    try:
        # The order of the draws below determines the generated values;
        # keep it unchanged to preserve reproducibility.
        if nb_cov < 2:
            covariates = None
        else:
            covariates = torch.randint(
                low=-1,
                high=2,
                size=(n_samples, nb_cov - 1),
                dtype=torch.float64,
                device=DEVICE,
            )
        coef = torch.randn(nb_cov, dim, device=DEVICE)
        offsets = torch.randint(
            low=0, high=2, size=(n_samples, dim), dtype=torch.float64, device=DEVICE
        )
    finally:
        # Restore the caller's RNG state even when a draw raises (for
        # instance on invalid sizes).
        torch.random.set_rng_state(prev_state)
    return offsets, covariates, coef
def get_simulated_count_data(
    n_samples=100, dim=25, rank=5, nb_cov=1, return_true_param=False, seed=0
):
    """Simulate count data together with its covariates and offsets.

    The components, offsets, covariates and coefficients come from
    ``get_components_simulation`` and ``get_simulation_offsets_cov_coef``;
    the counts are then produced by ``sample_pln``.

    Parameters
    ----------
    n_samples : int, optional
        Number of samples, forwarded to ``get_simulation_offsets_cov_coef``.
    dim : int, optional
        Number of variables, forwarded to both helper functions.
    rank : int, optional
        Number of columns of the components matrix.
    nb_cov : int, optional
        Number of coefficient rows; with the default of 1 the helper returns
        ``None`` for the covariates.
    return_true_param : bool, optional
        When truthy, also return the true covariance and coefficients.
    seed : int, optional
        Forwarded to ``sample_pln`` as its ``seed`` argument.

    Returns
    -------
    tuple
        ``(counts, cov, offsets)``, or
        ``(counts, cov, offsets, true_covariance, true_coef)`` when
        ``return_true_param`` is truthy, where
        ``true_covariance = components @ components.T``.
    """
    components = get_components_simulation(dim, rank)
    offsets, cov, true_coef = get_simulation_offsets_cov_coef(n_samples, nb_cov, dim)
    true_covariance = torch.matmul(components, components.T)
    counts, _, _ = sample_pln(components, true_coef, cov, offsets, seed=seed)
    # Truthiness rather than ``is True`` so that flags such as 1 or
    # numpy.bool_(True) are honoured instead of being silently ignored.
    if return_true_param:
        return counts, cov, offsets, true_covariance, true_coef
    return counts, cov, offsets
def get_real_count_data(n_samples=270, dim=100):
    """Load the top-left corner of the real count dataset.

    Reads ``../example_data/real_data/Y_mark.csv`` (path relative to the
    current working directory) and returns its first ``n_samples`` rows and
    first ``dim`` columns. Requests larger than 270 samples or 100 variables
    are clamped to those sizes with a warning.

    Parameters
    ----------
    n_samples : int, optional
        Number of rows to return, at most 270.
    dim : int, optional
        Number of columns to return, at most 100.

    Returns
    -------
    numpy.ndarray
        The selected slice of the CSV values.
    """
    # The guard uses 270, the size stated by the warning and by the default
    # argument, so that every oversized request triggers the warning.
    if n_samples > 270:
        warnings.warn(
            f"\nTaking the whole 270 samples of the dataset. Requested:n_samples={n_samples}, returned:270"
        )
        n_samples = 270
    if dim > 100:
        warnings.warn(
            f"\nTaking the whole 100 variables. Requested:dim={dim}, returned:100"
        )
        dim = 100
    counts = pd.read_csv("../example_data/real_data/Y_mark.csv").values[
        :n_samples, :dim
    ]
    print(f"Returning dataset of size {counts.shape}")
    return counts
def closest(lst, element):
    """Return the entry of ``lst`` that is nearest to ``element``.

    Parameters
    ----------
    lst : array-like
        Candidate values; converted with ``numpy.asarray``.
    element : scalar
        Value to approximate.

    Returns
    -------
    The entry of ``lst`` with the smallest absolute difference to
    ``element``. On ties the first such entry is returned, since ``argmin``
    reports the first minimum. An empty ``lst`` makes ``argmin`` raise
    ``ValueError``.
    """
    lst = np.asarray(lst)
    idx = (np.abs(lst - element)).argmin()
    return lst[idx]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment