Skip to content

Commit 18dae8e

Browse files
fschlimbGitHub Enterprise
authored andcommitted
Merge pull request IntelPython#6 from SAT/order_fix
Order fix TC is happy
2 parents 1d92481 + 843ab55 commit 18dae8e

File tree

2 files changed

+28
-28
lines changed

2 files changed

+28
-28
lines changed

generator/gen_daal4py.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def __init__(self, name):
9393
self.steps = set()
9494
self.children = set()
9595

96+
###############################################################################
97+
def ignored(ns, a=None):
98+
return ns in ignore and ((a != None and a in ignore[ns]) or (a == None and not ignore[ns]))
99+
96100

97101
###############################################################################
98102
###############################################################################
@@ -150,7 +154,7 @@ def read(self):
150154
for filename in filenames:
151155
if filename.endswith('.h') and not 'neural_networks' in dirpath and not any(filename.endswith(x) for x in cython_interface.ignore_files):
152156
fname = jp(dirpath,filename)
153-
print('reading ' + fname)
157+
#print('reading ' + fname)
154158
with open(fname, "r") as header:
155159
parsed_data = parse_header(header, cython_interface.ignores)
156160

@@ -234,7 +238,7 @@ def get_all_attrs(self, ns, cls, attr, ons=None):
234238
tmp = getattr(self.namespace_dict[ns].classes[cls], attr)
235239
for a in tmp:
236240
n = a if '::' in a else ns + '::' + a
237-
if ons not in ignore or n not in ignore[ons]:
241+
if not ignored(ons, n):
238242
pmembers[n] = tmp[a]
239243
for parent in self.namespace_dict[ns].classes[cls].parent:
240244
parentclass = cls
@@ -247,7 +251,7 @@ def get_all_attrs(self, ns, cls, attr, ons=None):
247251
pms = self.get_all_attrs(pns, parentclass, attr, ons)
248252
for x in pms:
249253
# ignore duplicates from parents
250-
if (ons not in ignore or x not in ignore[ons]) and not any(x == y for y in pmembers):
254+
if not ignored(ons, x) and not any(x == y for y in pmembers):
251255
pmembers[x] = pms[x]
252256
return pmembers
253257

@@ -390,12 +394,12 @@ def get_expand_attrs(self, ns, cls, attr):
390394
assert ins in self.namespace_dict
391395
assert inp in self.namespace_dict[ins].enums
392396
hlt = self.to_hltype(ns, attrs[i])
393-
if ns in ignore and '::'.join([ins, inp]) in ignore[ns]:
397+
if ignored(ns, '::'.join([ins, inp])):
394398
continue
395399
if hlt:
396400
if hlt[1] in ['stdtype', 'enum', 'class']:
397401
for e in self.namespace_dict[ins].enums[inp]:
398-
if not any(e in x for x in explist) and (ins not in ignore or e not in ignore[ins]):
402+
if not any(e in x for x in explist) and not ignored(ins, e):
399403
explist.append((ins, e, hlt[0]))
400404
else:
401405
print("// Warning: ignoring " + ns + " " + str(hlt))
@@ -460,7 +464,7 @@ def prepare_modelmaps(self, ns, mname='Model'):
460464
###############################################################################
461465
def expand_typedefs(self, ns):
462466
"""
463-
We expand all typedefs in classes/namespaces wihtout recursing
467+
We expand all typedefs in classes/namespaces without recursing
464468
to outer scopes or namespaces.
465469
"""
466470
def expand_td(typedefs):
@@ -604,7 +608,7 @@ def prepare_hlwrapper(self, ns, mode, func):
604608
jparams['params_opt'] = OrderedDict()
605609
for p in parms:
606610
pns, tmp = splitns(p)
607-
if not tmp.startswith('_') and (pns not in ignore or tmp not in ignore[pns]):
611+
if not tmp.startswith('_') and not ignored(pns, tmp):
608612
hlt = self.to_hltype(pns, parms[p])
609613
if hlt and hlt[1] in ['stdtype', 'enum', 'class']:
610614
(hlt, hlt_type, hlt_ns) = hlt
@@ -647,7 +651,7 @@ def prepare_hlwrapper(self, ns, mode, func):
647651
reqi = 0
648652
for ins, iname, itype in expinputs[0]:
649653
tmpi = iname
650-
if tmpi and (ns not in ignore or tmpi not in ignore[ns]):
654+
if tmpi and not ignored(ns, tmpi):
651655
if ns in defaults and tmpi in defaults[ns]:
652656
i = len(tmp_iargs_decl)
653657
dflt = ' = ' + defaults[ns][tmpi]
@@ -717,28 +721,23 @@ def hlapi(self, algo_patterns):
717721
algos = [x for x in self.namespace_dict if any(y in x for y in algo_patterns)] if algo_patterns else self.namespace_dict
718722
algos = [x for x in algos if not any(y in x for y in ['quality_metric', 'transform'])]
719723

720-
# we first extract and prepare the data (input, parameters, results, template spec)
721-
# some algo need to combine several configs, like kmeans needs kmeans::init
724+
# First expand typedefs
722725
for ns in algos + ['algorithms::classifier', 'algorithms::linear_model',]:
723-
# expand typedefs
724726
self.expand_typedefs(ns)
725-
if not ns.startswith('algorithms::neural_networks'):
726-
if not any(ns.endswith(x) for x in ['objective_function', 'iterative_solver']):
727-
nn = ns.split('::')
728-
if nn[0] == 'daal':
729-
if nn[1] == 'algorithms':
730-
func = '_'.join(nn[2:])
731-
else:
732-
func = '_'.join(nn[1:])
733-
elif nn[0] == 'algorithms':
734-
func = '_'.join(nn[1:])
727+
# Next, extract and prepare the data (input, parameters, results, template spec)
728+
for ns in algos + ['algorithms::classifier', 'algorithms::linear_model',]:
729+
if not ignored(ns):
730+
nn = ns.split('::')
731+
if nn[0] == 'daal':
732+
if nn[1] == 'algorithms':
733+
func = '_'.join(nn[2:])
735734
else:
736-
func = '_'.join(nn)
737-
#if any(ns.endswith(x) for x in ['prediction', 'training', 'init', 'transform']):
738-
# func = '_'.join(tmp[1:])
739-
#else:
740-
# func = tmp[-1]
741-
algoconfig.update(self.prepare_hlwrapper(ns, 'Batch', func))
735+
func = '_'.join(nn[1:])
736+
elif nn[0] == 'algorithms':
737+
func = '_'.join(nn[1:])
738+
else:
739+
func = '_'.join(nn)
740+
algoconfig.update(self.prepare_hlwrapper(ns, 'Batch', func))
742741

743742
# and now we can finally generate the code
744743
wg = wrapper_gen(algoconfig, {cpp2hl(i): ifaces[i] for i in ifaces})

generator/wrappers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@
111111
'correctionPairs', 'correctionIndices', 'averageArgumentLIterations',],
112112
'algorithms::optimization_solver::adagrad': ['optionalArgument', 'algorithms::optimization_solver::iterative_solver::OptionalResultId',
113113
'gradientSquareSum'],
114-
'algorithms::optimization_solver::objective_function': ['argument',],
114+
'algorithms::optimization_solver::objective_function': [],
115+
'algorithms::optimization_solver::iterative_solver': [],
115116
}
116117

117118
# List of InterFaces, classes that can be arguments to other algorithms

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy