from pysat.solvers import Minisat22
import itertools
import sys

n_student = 15
n_day = (n_student - 1) // 2
n_group = n_student // 3

# Variable numbering for the SAT encoding (ids start at 1).
#   (d, g, i)      -- student i is placed in group g on day d
#   (d, g, (i, j)) -- pair variable for students i < j in group g on day d,
#                     tied by link_symbols to (d, g, i) AND (d, g, j)
# varmap2 maps each key to its id and mapback2 is the inverse mapping.
# All per-student keys are numbered before any pair key, so ids
# 1..n_day*n_group*n_student are exactly the placement variables.
# idx2 ends up holding the next unused id.
idx2 = 1
varmap2 = {}
mapback2 = {}

_student_keys = itertools.product(range(n_day), range(n_group), range(n_student))
_pair_keys = itertools.product(
    range(n_day), range(n_group), itertools.combinations(range(n_student), 2)
)
for _key in itertools.chain(_student_keys, _pair_keys):
    varmap2[_key] = idx2
    mapback2[idx2] = _key
    idx2 += 1

def link_symbols(solver, ivarmap):
    """Tie every pair variable to the two student variables it combines.

    For each day d, group g and student pair i < j, adds the clauses for
    ``pair(d, g, i, j) <-> student(d, g, i) AND student(d, g, j)``.

    Args:
        solver: object exposing ``add_clause(list_of_int_literals)``.
        ivarmap: mapping from ``(d, g, i)`` / ``(d, g, (i, j))`` keys to
            positive variable ids.
    """
    student_pairs = list(itertools.combinations(range(n_student), 2))
    for day, grp in itertools.product(range(n_day), range(n_group)):
        for a, b in student_pairs:
            lit_a = ivarmap[(day, grp, a)]
            lit_b = ivarmap[(day, grp, b)]
            both = ivarmap[(day, grp, (a, b))]
            solver.add_clause([-lit_a, -lit_b, both])  # a AND b -> pair
            solver.add_clause([-both, lit_a])  # pair -> a
            solver.add_clause([-both, lit_b])  # pair -> b
 
# Each pair of students walks together exactly once.
def pair_once(solver, ivarmap):
    """Require every pair of students to share a group exactly once overall.

    For each pair i < j, the pair variables over all (day, group) slots get
    pairwise at-most-one clauses followed by a single at-least-one clause.

    Args:
        solver: object exposing ``add_clause(list_of_int_literals)``.
        ivarmap: mapping from ``(d, g, (i, j))`` keys to variable ids.
    """
    slots = list(itertools.product(range(n_day), range(n_group)))
    for pair in itertools.combinations(range(n_student), 2):
        lits = [ivarmap[(day, grp, pair)] for day, grp in slots]
        # At most one slot may host this pair.
        for x, y in itertools.combinations(lits, 2):
            solver.add_clause([-x, -y])
        # At least one slot must host it.
        solver.add_clause(lits)

# Pair-transitivity clauses, intended (together with the other constraints) to force exactly 3 students per group.
def triangle(solver, ivarmap):
    """Add pair-transitivity clauses inside every (day, group) slot.

    For students c1 < c2 < c3 it adds the clause
    ``pair(c1, c2) AND pair(c1, c3) -> pair(c2, c3)``.

    Note: when the link_symbols clauses are also present, this implication
    already follows from them. The two antecedents force c2 and c3 into the
    group, and that forces pair(c2, c3). These clauses are therefore
    redundant but harmless.

    Args:
        solver: object exposing ``add_clause(list_of_int_literals)``.
        ivarmap: mapping from ``(d, g, (i, j))`` keys to variable ids.
    """
    for day, grp in itertools.product(range(n_day), range(n_group)):
        for a, b, c in itertools.combinations(range(n_student), 3):
            ab = ivarmap[(day, grp, (a, b))]
            ac = ivarmap[(day, grp, (a, c))]
            bc = ivarmap[(day, grp, (b, c))]
            solver.add_clause([-ab, -ac, bc])

# Each student walks in exactly one group on each day.
def one_group(solver, ivarmap):
    """Place every student in exactly one group on each day.

    For each (day, student), adds pairwise at-most-one clauses over that
    student's group variables, then one at-least-one clause.

    Args:
        solver: object exposing ``add_clause(list_of_int_literals)``.
        ivarmap: mapping from ``(d, g, i)`` keys to variable ids.
    """
    for day, stu in itertools.product(range(n_day), range(n_student)):
        choices = [ivarmap[(day, grp, stu)] for grp in range(n_group)]
        for x, y in itertools.combinations(choices, 2):
            solver.add_clause([-x, -y])
        solver.add_clause(choices)

def sol_4B():
    """Build the full encoding and return one satisfying model.

    Returns:
        The model as the list of signed integer literals produced by
        ``get_model()``, or ``None`` if the instance is unsatisfiable.
    """
    # The context manager frees the underlying native solver on exit.
    # Previously the solver object was never deleted.
    with Minisat22() as s:
        pair_once(s, varmap2)
        triangle(s, varmap2)
        one_group(s, varmap2)
        link_symbols(s, varmap2)
        if not s.solve():
            return None
        # get_model() returns a plain Python list, so it stays valid after
        # the solver has been deleted.
        return s.get_model()

def sol_4C(n_solutions=10):
    """Enumerate up to ``n_solutions`` distinct schedules.

    After each model is found, a blocking clause forbids that exact
    assignment of the student/group placement variables. In varmap2 these
    are numbered first, as ids 1..n_student*n_group*n_day. Each later model
    therefore differs from every earlier one in at least one placement.

    Args:
        n_solutions: maximum number of models to return. The default of 10
            matches the previous fixed count.

    Returns:
        A list of models (lists of signed integer literals). The list is
        shorter than ``n_solutions`` if the solver runs out of solutions.
    """
    n_primary = n_student * n_group * n_day
    res = []
    with Minisat22() as s2:
        pair_once(s2, varmap2)
        triangle(s2, varmap2)
        one_group(s2, varmap2)
        link_symbols(s2, varmap2)
        for _ in range(n_solutions):
            # Stop once no new solution exists. Previously get_model() then
            # returned None and building the blocking clause raised TypeError.
            if not s2.solve():
                break
            model = s2.get_model()
            res.append(model)
            # Block this exact placement. This relies on model[k] being the
            # literal of variable k + 1, as the original slicing did.
            # Slicing before negating avoids negating every auxiliary literal.
            s2.add_clause([-lit for lit in model[:n_primary]])
    return res

# This teaches students how to print the truth assignments in a human-readable form.
# def print_sol(imodel):
#     tot = n_student * n_group * n_day
#     for var in imodel[:tot]:
#         if var > 0:
#             print(mapback2[abs(var)])
#     print("=" * 100)

# print_sol(sol_4B())
# sols = sol_4C()
# for sol in sols:
#     print_sol(sol)

# Check whether any two solutions are identical on the student/group placement variables.
# import numpy as np
# for (i, j) in itertools.combinations(range(10), 2):
#     a = np.array(sols[i][:n_student * n_group * n_day])
#     b = np.array(sols[j][:n_student * n_group * n_day])
#     print(np.all(a-b==0))

