-
Notifications
You must be signed in to change notification settings - Fork 0
/
smt_solver.py
131 lines (118 loc) · 4.29 KB
/
smt_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from heuristics import row_packing_partition
from tools import MyMatrix, RectangularPartition
import z3
def _smt_add_label_bound(solver, f, num1, b):
    """Constrain every cell index 0..num1-1 to an unsigned label below ``b``."""
    for idx in range(num1):
        solver.add(z3.ULT(f(idx), b))


def _smt_add_rectangle_constraints(solver, f, A, ones, cell_index):
    """Force every label class of ``f`` to be a combinatorial rectangle of 1-cells.

    For two 1-cells (i, j) and (ii, jj) in different rows and columns:
    - if A[i][jj] == 0, they cannot share a label;
    - if A[i][jj] == 1, sharing a label forces (i, jj) to carry it as well.
    The mirrored corner (ii, j) is covered when the pair is visited in the
    opposite order.
    """
    for (i, j) in ones:
        here = f(cell_index[(i, j)])
        for (ii, jj) in ones:
            if i == ii or j == jj:
                continue
            there = f(cell_index[(ii, jj)])
            corner = A[i][jj]
            if corner == 0:
                solver.add(z3.Not(here == there))
            elif corner == 1:
                solver.add(
                    z3.Implies(here == there, here == f(cell_index[(i, jj)]))
                )


def _smt_read_partition(model, f, ones):
    """Group the 1-cells by the label ``model`` assigns them via ``f``.

    Returns one ``{"rows": [...], "cols": [...]}`` dict per label that is
    actually used, ordered by label value.  Rows and columns keep the
    row-major order in which they are first seen.  Unused labels produce
    no entry.
    """
    groups = {}
    for idx, (i, j) in enumerate(ones):
        # model_completion guarantees a concrete numeral even if the model
        # leaves some application of f unconstrained.
        label = model.evaluate(f(idx), model_completion=True).as_long()
        rows, cols = groups.setdefault(label, ([], []))
        if i not in rows:
            rows.append(i)
        if j not in cols:
            cols.append(j)
    return [
        {"rows": rows, "cols": cols}
        for _, (rows, cols) in sorted(groups.items())
    ]


def smt_euf_partition(A, known_solution=None, if_print=False):
    """Shrink a partition of the 1-entries of ``A`` into rectangles using z3.

    The search starts from ``known_solution``, or, when that is falsy, from
    ``row_packing_partition(A, 10)``.  It repeatedly asks z3 for a strictly
    smaller partition.  Each 1-cell gets an index, and an uninterpreted
    bit-vector function ``rect`` maps that index to a rectangle label.
    The search stops when z3 answers UNSAT, or when the size drops below
    ``MyMatrix(A).real_rank()``, which the loop treats as a lower bound.

    Args:
        A: matrix indexed as ``A[i][j]``.  Entries equal to 1 are the cells
            to cover; a 0 at a rectangle's corner forbids that rectangle.
        known_solution: optional starting partition.  It must work with
            ``len()`` and ``RectangularPartition``.  Presumably it has the
            same ``{"rows", "cols"}`` dict shape this function returns;
            confirm with the callers.
        if_print: print progress and visualize intermediate partitions.

    Returns:
        The smallest partition found.  This is either the starting partition
        unchanged, or a list of ``{"rows": [...], "cols": [...]}`` dicts
        with no empty rectangles.

    Raises:
        ValueError: if z3 answers neither sat nor unsat.
    """
    if known_solution:
        partition = known_solution
    else:
        partition = row_packing_partition(A, 10)
    if if_print:
        print("-------------------heuristic result-------------------")
        RectangularPartition(A, partition).visualize()
        print("------------------------------------------------------")
    # b = number of rectangles the next SAT query allows (labels 0..b-1).
    b = len(partition) - 1
    mat = MyMatrix(A)
    real_rank = mat.real_rank()
    if real_rank == len(partition):
        return partition
    (M, N) = mat.dimensions()

    # construct SAT: one bit-vector index per 1-cell, in row-major order.
    ones = [(i, j) for i in range(M) for j in range(N) if A[i][j] == 1]
    cell_index = {cell: idx for idx, cell in enumerate(ones)}
    num1 = len(ones)
    solver = z3.Solver()
    # Range width is sized for the initial b; later bounds are smaller, so
    # they still fit.
    f = z3.Function(
        'rect',
        z3.BitVecSort(max(1, num1.bit_length())),
        z3.BitVecSort(max(1, b.bit_length()))
    )
    _smt_add_label_bound(solver, f, num1, b)
    _smt_add_rectangle_constraints(solver, f, A, ones, cell_index)

    while b >= real_rank:
        if if_print:
            print(f"-------------------trying rank={b} with SAT")
        check_result = solver.check()
        if check_result == z3.unsat:
            if if_print:
                print(f"-------------------rank={b} UNSAT")
            break
        if check_result != z3.sat:
            raise ValueError("z3 not returning")
        partition = _smt_read_partition(solver.model(), f, ones)
        if if_print:
            RectangularPartition(A, partition).visualize()
        # The model proves a partition with len(partition) rectangles
        # exists, which may be fewer than b.  The next query asks for one
        # fewer than that.
        b = len(partition) - 1
        _smt_add_label_bound(solver, f, num1, b)
    return partition
def fooling_set(A, b):
    """Decide whether ``A`` admits a fooling set of size at least ``b``.

    A fooling set here is a set of 1-entries in which no two picks
    (r1, c1) and (r2, c2) have both cross entries A[r1][c2] and A[r2][c1]
    equal to 1.  In particular, this rules out two picks in one row or one
    column.  The question is posed to z3 with one Boolean per matrix cell.

    Args:
        A: matrix indexed as ``A[i][j]``.
        b: required minimum size of the fooling set.

    Returns:
        True if z3 reports the constraints satisfiable.  False otherwise,
        which includes a z3 ``unknown`` answer.
    """
    n_rows, n_cols = MyMatrix(A).dimensions()
    picked = [
        [z3.Bool(f'in_row{m}_col{n}') for n in range(n_cols)]
        for m in range(n_rows)
    ]
    solver = z3.Solver()

    # Zero entries can never be part of the set.
    for r in range(n_rows):
        for c in range(n_cols):
            if A[r][c] == 0:
                solver.add(z3.Not(picked[r][c]))

    # Pairwise exclusion among 1-entries.  The test is symmetric, so each
    # unordered pair is visited once.  When two cells share a row or column,
    # both cross entries coincide with the cells themselves (both 1), so such
    # pairs are excluded by the same rule.
    one_cells = [(r, c) for r in range(n_rows) for c in range(n_cols) if A[r][c] == 1]
    for pos, (r1, c1) in enumerate(one_cells):
        for r2, c2 in one_cells[pos + 1:]:
            if A[r1][c2] == 1 and A[r2][c1] == 1:
                solver.add(z3.Or(z3.Not(picked[r1][c1]), z3.Not(picked[r2][c2])))

    chosen_count = sum(
        [z3.If(picked[r][c], 1, 0) for r in range(n_rows) for c in range(n_cols)]
    )
    solver.add(b <= chosen_count)
    return solver.check() == z3.sat