1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157#!/usr/bin/env python
import sys
import re
import os
import readline # NOQA
import ply.lex
import ply.cpp
# Python 2-3 compatibility.
try:
    # Python 2: make input() behave like Python 3's (no eval of the answer).
    input = raw_input
except NameError:
    pass
try:
    os_replace = os.replace
except AttributeError:
    def os_replace(src, dst):
        """Move ``src`` to ``dst``, overwriting ``dst`` (fallback for Python < 3.3)."""
        if os.name == "nt":
            # os.rename on Windows will not overwrite an existing file, so the
            # old dst is removed first. A dst that does not exist yet is fine.
            try:
                os.unlink(dst)
            except OSError as err:
                import errno
                if err.errno != errno.ENOENT:
                    raise
        os.rename(src, dst)
# Directory holding this script and its header_database.txt.
IWYU_PATH = os.path.dirname(os.path.realpath(__file__))
# A wildcard: one or more `::`-separated names in which `*` may appear.
VALID_WILDCARD = re.compile(r"(?:[*a-zA-Z_][*a-zA-Z0-9_]*\s*::\s*)*[*a-zA-Z_][*a-zA-Z0-9_]*")
# A possibly `::`-qualified C/C++ identifier.
IDENTIFIERS = re.compile(r"\b(?:[a-zA-Z_][a-zA-Z0-9_]*\s*::\s*)*[a-zA-Z_][a-zA-Z0-9_]*")


def validate_wildcard(wildcard, linenr=0):
    """Check a database wildcard and convert it to a regex fragment.

    Args:
        wildcard: text such as ``std::*`` or ``*_t``; surrounding whitespace
            is ignored, as is whitespace around ``::``.
        linenr: 0-based database line number, used in the error message.

    Returns:
        The wildcard with ``*`` replaced by ``.*``.

    Raises:
        RuntimeError: if the whole (stripped) text is not a valid wildcard.
    """
    wildcard = wildcard.strip()
    match = VALID_WILDCARD.match(wildcard)
    # re.match anchors only at the start; require the entire string to match
    # so trailing junk such as "foo bar" or "foo(" is rejected.
    if match is None or match.end() != len(wildcard):
        raise RuntimeError("invalid wildcard on line {}".format(linenr + 1))
    # Identifiers are produced as "a::b" with no spaces, so normalize
    # "a :: b" or the pattern could never match anything.
    wildcard = re.sub(r"\s*::\s*", "::", wildcard)
    return wildcard.replace("*", ".*")
def make_wildcards_regex(wildcards):
    """Compile every wildcard pattern into one anchored alternation.

    Each alternative is prefixed with an empty capturing group ``()``. After a
    successful match, ``m.lastindex - 1`` is therefore the index in
    ``wildcards`` of the first pattern that matches the whole string.

    Args:
        wildcards: list of entries whose second element (``w[1]``) is a
            regex fragment.
    """
    alternatives = "|".join("(?:(){})".format(w[1]) for w in wildcards)
    # The outer (?:...) makes ^ and $ apply to every alternative. Without it
    # they bind only to the first and last one, and the other entries could
    # match a mere prefix of an identifier.
    return re.compile("^(?:{})$".format(alternatives))
def load_database():
    """Load the identifier/header database from ``header_database.txt``.

    The file is read from ``IWYU_PATH`` (this script's directory). Each
    non-blank line has one of these forms::

        ? wildcard
        ! wildcard
        wildcard = header        (left side contains ``*``)
        identifier = header

    Returns:
        ``(headers, wildcards)``.
        - ``headers`` maps exact identifiers to header strings.
        - ``wildcards`` is a list of ``["?", regex]``, ``["!", regex]`` or
          ``["=", regex, header]`` entries.
        The list is stably sorted so that all ``?`` entries come last; in the
        combined wildcard regex, ``!`` and ``=`` entries are then tried first.

    Raises:
        RuntimeError: for an invalid wildcard or an unrecognized line
            (including a line with more than one ``=``).
    """
    headers = {}
    wildcards = []
    with open(os.path.join(IWYU_PATH, "header_database.txt")) as db:
        for linenr, line in enumerate(db):
            line = line.strip()
            if line[:1] in ("?", "!"):
                wildcards.append([line[0], validate_wildcard(line[1:], linenr)])
            elif line.count('=') == 1:
                lhs, rhs = line.split('=')
                if '*' in lhs:
                    wildcards.append(["=", validate_wildcard(lhs, linenr), rhs.strip()])
                else:
                    headers[lhs.strip()] = rhs.strip()
            elif line:
                # Report 1-based line numbers, like validate_wildcard does.
                raise RuntimeError("database format error on line {}".format(linenr + 1))
    wildcards.sort(key=lambda w: w[0] == "?")
    return headers, wildcards
def store_database(headers, wildcards):
    """Write ``headers`` and ``wildcards`` back to ``header_database.txt``.

    The output uses the text format that ``load_database`` reads:
    1. ``?`` / ``!`` wildcard lines, with ``?`` entries first.
    2. ``pattern = header`` wildcard lines.
    3. Exact ``identifier = header`` mappings, sorted by header, then by
       identifier.
    Regex ``.*`` is turned back into the file's ``*`` syntax.

    The data goes to ``header_database.txt~`` first and is then moved over the
    real file with ``os_replace``, so a failed write leaves the existing
    database intact. If anything fails before the move, the temporary file is
    removed.
    """
    tmp = os.path.join(IWYU_PATH, "header_database.txt~")
    dst = os.path.join(IWYU_PATH, "header_database.txt")
    try:
        with open(tmp, "w") as db:
            # sorted() is stable: "?" entries first, others keep their order.
            for w in sorted(wildcards, key=lambda w: w[0] != "?"):
                if w[0] in "?!":
                    db.write("{} {}\n".format(w[0], w[1].replace('.*', '*')))
                else:
                    db.write("{} = {}\n".format(w[1].replace('.*', '*'), w[2]))
            # kv[::-1] is (header, identifier): group entries by header.
            for identifier, header in sorted(headers.items(), key=lambda kv: kv[::-1]):
                db.write("{} = {}\n".format(identifier, header))
        os_replace(tmp, dst)
    finally:
        # After a successful os_replace tmp is gone; otherwise drop the
        # partially written file.
        if os.path.exists(tmp):
            try:
                os.unlink(tmp)
            except OSError:
                pass
def get_identifiers(code):
    """Yield every (possibly ``::``-qualified) identifier in C/C++ ``code``.

    ``code`` is passed through ``ply.cpp.trigraph`` and tokenized with the
    ``ply.cpp`` lexer rules. A ``CPP_ID`` token followed by ``:``, ``:``,
    ``CPP_ID`` tokens is merged with it, so ``std::vector`` is yielded as one
    string. Identifiers are yielded in source order, duplicates included.

    Assumes ply.cpp emits each ``:`` as a literal token of type ``":"``.
    NOTE(review): presumably whitespace is a separate token, so ``a :: b`` is
    not merged — verify against ply.cpp.
    """
    lex = ply.lex.lex(module=ply.cpp)
    # Tokenize the argument itself, not the module-level ``data`` global.
    lex.input(ply.cpp.trigraph(code))
    tokens = list(lex)
    i = 0
    while i < len(tokens):
        if tokens[i].type != "CPP_ID":
            i += 1
            continue
        ident = tokens[i].value
        i += 1
        while [t.type for t in tokens[i:i + 3]] == [":", ":", "CPP_ID"]:
            ident += "::" + tokens[i + 2].value
            i += 3
        yield ident
if len(sys.argv) <= 1:
    # Missing arguments is a usage error: report it on stderr, exit non-zero.
    sys.stderr.write("Usage: {} <file>...\n".format(sys.argv[0]))
    sys.exit(1)
headers, wildcards = load_database()
# Rewrite the database immediately so it is stored in canonical form.
store_database(headers, wildcards)
wildcards_regex = make_wildcards_regex(wildcards)
for filename in sys.argv[1:]:
    with open(filename) as f:
        data = f.read()
    print("Includes for {}:".format(filename))
    headers_needed = set()
    identifiers_handled = set()
    for identifier in get_identifiers(data):
        if identifier in identifiers_handled:
            continue
        identifiers_handled.add(identifier)
        if identifier in headers:
            headers_needed.add(headers[identifier])
            continue
        m = wildcards_regex.match(identifier)
        if not m:
            continue
        wildcard = wildcards[m.lastindex - 1]
        if wildcard[0] == "!":
            continue
        if wildcard[0] == "=":
            headers_needed.add(wildcard[2])
            continue
        # "?" wildcard: ask which header to include ("!" means ignore).
        # Re-ask on an empty answer instead of recording an empty header.
        header = ""
        while not header:
            sys.stderr.write(identifier + ": ")
            sys.stderr.flush()
            header = input().strip()
        if header == "!":
            wildcards.append(["!", identifier])
            # Keep "?" entries last (as load_database does) so the new "!"
            # entry takes precedence over the "?" pattern that triggered
            # this prompt in later files.
            wildcards.sort(key=lambda w: w[0] == "?")
            wildcards_regex = make_wildcards_regex(wildcards)
        else:
            headers[identifier] = header
            headers_needed.add(header)
        store_database(headers, wildcards)
    # Replacing '"' with '~' (which sorts after '<') lists <system> headers
    # before "local" ones.
    for header_needed in sorted(headers_needed, key=lambda h: h.replace('"', "~")):
        print("#include {}".format(header_needed))
    sys.stdout.write("\n")
    sys.stdout.flush()
store_database(headers, wildcards)