Flink runner not splitting tasks when parallelism is turned on in BEAM python pipeline - apache-flink

I have a beam pipeline written in python that when deployed to a flink runner doesn't make use of the parallelism correctly.
There is unbounded data coming in through a kafka connector and I want the data to be read when split in parallel.
My understanding is that it should split up the tasks but as shown in the image one parallelism is used and all the other 5 sub tasks finished instantly leaving the one running to do all the work.
The pipeline settings are:
options = PipelineOptions([
"--runner=PortableRunner",
"--sdk_worker_parallelism=3",
"--artifact_endpoint=localhost:8098",
"--job_endpoint=localhost:8099",
"--environment_type=EXTERNAL",
"--environment_config=localhost:50000",
"--checkpointing_interval=30000",
])
options._all_options['parallelism'] = 3
Is this a missing config on the Flink runner or something that can be configured in the BEAM pipeline?
The full pipeline:
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
options = PipelineOptions([
"--runner=PortableRunner",
"--sdk_worker_parallelism=3",
"--artifact_endpoint=localhost:8098",
"--job_endpoint=localhost:8099",
"--environment_type=EXTERNAL",
"--environment_config=localhost:50000",
"--checkpointing_interval=30000",
])
options._all_options['parallelism'] = 3
class CountProvider(beam.RestrictionProvider):
def __init__(self, initial_split_size=5):
self._initial_split_size = initial_split_size
self.OffsetRestrictionTracker = None
def imports(self):
if self.OffsetRestrictionTracker is not None: return
from apache_beam.io.restriction_trackers import OffsetRestrictionTracker, OffsetRange
self.OffsetRestrictionTracker = OffsetRestrictionTracker
self.OffsetRange = OffsetRange
def initial_restriction(self, element):
self.imports()
return self.OffsetRange(0, 10)
def create_tracker(self, restriction):
self.imports()
return self.OffsetRestrictionTracker(restriction)
def restriction_size(self, element, restriction):
return restriction.size()*100_000
def split(self, element, restriction):
self.imports()
if restriction.start + 1 >= restriction.stop:
yield self.OffsetRange(restriction.start, restriction.stop)
else:
last_val = restriction.start
for i in range(1, self._initial_split_size):
next_stop = i * (restriction.start + restriction.stop) // self._initial_split_size
yield self.OffsetRange(last_val, next_stop)
last_val = next_stop
yield self.OffsetRange(last_val, restriction.stop)
class CountFn(beam.DoFn):
def setup(self):
print("setup")
def process(self, element, tracker=beam.DoFn.RestrictionParam(CountProvider())):
res = tracker.current_restriction()
print(f"Current Restriction {res.start}, {res.stop}")
for i in range(res.start, res.stop):
if not tracker.try_claim(i):
return
for j in range(10_000):
yield i, j
def get_initial_restriction(self, filename):
return (0, 10)
def teardown(self):
print("Teardown")
p = beam.Pipeline(options=options)
out = (p | f'Create' >> beam.Create([tuple()])
| f'Gen Data' >> beam.ParDo(CountFn())
| beam.Map(print)
)
result = p.run()
result.wait_until_finish()

Related

Using v4l2sink with DeepStream

I'm working on deepstream code to pass rtsp streams to virtual V4L2 devices (I used v4l2loopback to create the virtual devices). I have a code that works without errors, however, I can't read the V4L2 device.
Does anyone know of a working DeepStream code where v4l2sink is used? I have tried to find an example without success.
Here is my code. The writing part to v4l2sink is in the function: create_v4l2sink_branch()
import sys
import gi
gi.require_version('Gst', '1.0')
gi.require_version('GstRtspServer', '1.0')
import math
import sys
import common.utils as DS_UTILS
import pyds
from common.bus_call import bus_call
from common.FPS import PERF_DATA
from common.is_aarch_64 import is_aarch64
from gi.repository import GLib, Gst, GstRtspServer
CODEC="H264"
BITRATE=4000000
MAX_DISPLAY_LEN = 64
MUXER_OUTPUT_WIDTH = 1920
MUXER_OUTPUT_HEIGHT = 1080
MUXER_BATCH_TIMEOUT_USEC = 400000
TILED_OUTPUT_WIDTH = 1920
TILED_OUTPUT_HEIGHT = 1080
GST_CAPS_FEATURES_NVMM = "memory:NVMM"
OSD_PROCESS_MODE = 0
OSD_DISPLAY_TEXT = 1
MUX_SYNC_INPUTS = 0
ds_loop=None
perf_data = None
def terminate_pipeline(u_data):
global ds_loop
pass
# if global_config.request_to_stop == True:
# print("Aborting pipeline by request")
# ds_loop.quit()
# return False
return True
def create_onscreen_branch(pipeline, gst_elem, index):
print("Creating EGLSink")
sink = DS_UTILS.create_gst_element("nveglglessink", f"nvvideo-renderer-{index}")
sink.set_property('sync', 0)
sink.set_property('async', 1)
pipeline.add(sink)
if is_aarch64():
transform = DS_UTILS.create_gst_element("nvegltransform", f"nvegl-transform{index}")
pipeline.add(transform)
gst_elem.link(transform)
transform.link(sink)
else:
gst_elem.link(sink)
sink.set_property("qos", 0)
def create_v4l2sink_branch(pipeline, gst_elem, index, output_video_device):
# Create a caps filter
caps = DS_UTILS.create_gst_element("capsfilter", f"filter-{index}")
#caps.set_property("caps", Gst.Caps.from_string("video/x-raw(memory:NVMM), format=I420"))
#caps.set_property("caps", Gst.Caps.from_string("video/x-raw(memory:NVMM), format=NV12"))
identity = DS_UTILS.create_gst_element("identity", f"identity-{index}")
identity.set_property("drop-allocation", 1)
nvvidconv = DS_UTILS.create_gst_element("nvvideoconvert", f"convertor-{index}")
sink = DS_UTILS.create_gst_element("v4l2sink", f"v4l2sink-{index}")
sink.set_property('device', output_video_device)
sink.set_property("sync", 0)
sink.set_property("async", 1)
pipeline.add(caps)
pipeline.add(nvvidconv)
pipeline.add(identity)
pipeline.add(sink)
gst_elem.link(caps)
caps.link(nvvidconv)
nvvidconv.link(identity)
identity.link(sink)
def run_pipeline(rtsp_v4l2_pairs):
# Check input arguments
number_sources = len(rtsp_v4l2_pairs)
perf_data = PERF_DATA(number_sources)
# Standard GStreamer initialization
Gst.init(None)
# Create gstreamer elements */
# Create Pipeline element that will form a connection of other elements
print("Creating Pipeline")
pipeline = Gst.Pipeline()
is_live = False
if not pipeline:
sys.stderr.write(" Unable to create Pipeline \n")
return
# Create nvstreammux instance to form batches from one or more sources.
streammux = DS_UTILS.create_gst_element("nvstreammux", "Stream-muxer")
pipeline.add(streammux)
for i in range(number_sources):
uri_name = rtsp_v4l2_pairs[i][0]
print(" Creating source_bin {} --> {}".format(i, uri_name))
is_live = uri_name.find("rtsp://") == 0
source_bin = DS_UTILS.create_source_bin(i, uri_name)
pipeline.add(source_bin)
padname = "sink_%u" % i
sinkpad = streammux.get_request_pad(padname)
if not sinkpad:
sys.stderr.write("Unable to create sink pad bin \n")
srcpad = source_bin.get_static_pad("src")
if not srcpad:
sys.stderr.write("Unable to create src pad bin \n")
srcpad.link(sinkpad)
# streammux setup
if is_live:
print(" At least one of the sources is live")
streammux.set_property('live-source', 1)
streammux.set_property('width', MUXER_OUTPUT_WIDTH)
streammux.set_property('height', MUXER_OUTPUT_HEIGHT)
streammux.set_property('batch-size', number_sources)
streammux.set_property("batched-push-timeout", MUXER_BATCH_TIMEOUT_USEC)
#streammux.set_property("sync-inputs", MUX_SYNC_INPUTS)
queue = DS_UTILS.create_gst_element("queue", "queue1")
pipeline.add(queue)
nvstreamdemux = DS_UTILS.create_gst_element("nvstreamdemux", "nvstreamdemux")
pipeline.add(nvstreamdemux)
# linking
streammux.link(queue)
queue.link(nvstreamdemux)
for i in range(number_sources):
queue = DS_UTILS.create_gst_element("queue", f"queue{2+i}")
pipeline.add(queue)
demuxsrcpad = nvstreamdemux.get_request_pad(f"src_{i}")
if not demuxsrcpad:
sys.stderr.write("Unable to create demux src pad \n")
queuesinkpad = queue.get_static_pad("sink")
if not queuesinkpad:
sys.stderr.write("Unable to create queue sink pad \n")
demuxsrcpad.link(queuesinkpad)
#create_onscreen_branch(pipeline=pipeline, gst_elem=queue, index=i)
create_v4l2sink_branch(pipeline=pipeline, gst_elem=queue, index=i, output_video_device=rtsp_v4l2_pairs[i][1])
# for termate the pipeline
GLib.timeout_add_seconds(1, terminate_pipeline, 0)
# display FPS
GLib.timeout_add(5000, perf_data.perf_print_callback)
# create an event loop and feed gstreamer bus mesages to it
loop = GLib.MainLoop()
ds_loop = loop
bus = pipeline.get_bus()
bus.add_signal_watch()
bus.connect("message", bus_call, loop)
print("Starting pipeline")
# start play back and listed to events
pipeline.set_state(Gst.State.PLAYING)
try:
loop.run()
except:
pass
# cleanup
print("Pipeline ended")
pipeline.set_state(Gst.State.NULL)
if __name__ == '__main__':
import json
import sys
pairs = [
("rtsp://192.168.1.88:554/22", "/dev/video6")
]
run_pipeline(rtsp_v4l2_pairs=pairs)

Custom Locust User for SageMaker Endpoint Keeps running after time limit is reached

I have been trying to build a SagemakerUser from the base User class in the Locust library. The issue though is when I use it with a timed shape test, when said test ends (you can see a message: Shape test stopping) the load test shrugs it off and continues. Below is the script I have written to this end. My question is how is this behaviour explained?
import pandas as pd
from locust import HttpUser, User, task, TaskSet, events, LoadTestShape
from sagemaker.serializers import JSONSerializer
from sagemaker.session import Session
import sagemaker
import time
import sys
import math
import pdb
df = "some df to load samples from"
endpoint = "sage maker end point name"
class SagemakerClient(sagemaker.predictor.Predictor):
def predictEx(self, data):
start_time = time.time()
start_perf_counter = time.perf_counter()
name = 'predictEx'
try:
result = self.predict(data)
except:
total_time = int((time.perf_counter() - start_perf_counter) * 1000)
events.request_failure.fire(request_type="sagemaker", name=name, response_time=total_time, exception=sys.exc_info(), response_length=0)
else:
total_time = int((time.perf_counter() - start_perf_counter) * 1000)
events.request_success.fire(request_type="sagemaker", name=name, response_time=total_time, response_length=sys.getsizeof(result))
class SagemakerLocust(User):
abstract = True
def __init__(self, *args, **kwargs):
super(SagemakerLocust, self).__init__(*args, **kwargs)
self.client = SagemakerClient(
sagemaker_session = Session(),
endpoint_name = "sagemaker-test",
serializer = JSONSerializer())
class APIUser(SagemakerLocust):
#task
def call(self):
request = df.text.sample(1, weights=df.length).iloc[0]
self.client.predictEx(request)
class StepLoadShape(LoadTestShape):
"""
A step load shape
Keyword arguments:
step_time -- Time between steps
step_load -- User increase amount at each step
spawn_rate -- Users to stop/start per second at every step
time_limit -- Time limit in seconds
"""
step_time = 30#3600
step_load = 1
spawn_rate = 1
time_limit =2#3600*6
#pdb.set_trace()
def tick(self):
run_time = self.get_run_time()
if run_time > self.time_limit:
return None
current_step = math.floor(run_time / self.step_time) + 1
return (current_step * self.step_load, self.spawn_rate)

Sagemaker - batch transform] Internal server error : 500

I am trying to do a batch transform on a training dataset in an S3 bucket. I have followed this link:
https://github.com/aws-samples/quicksight-sagemaker-integration-blog
The training data on which transformation is being applied is of ~35 MB.
I am getting these errors:
Bad HTTP status received from algorithm: 500
The server encountered an internal error and was unable to complete your request. Either the server is overloaded or there is an error in the application.
Process followed:
1. s3_input_train = sagemaker.TrainingInput(s3_data='s3://{}/{}/rawtrain/'.format(bucket, prefix), content_type='csv')
2. from sagemaker.sklearn.estimator import SKLearn
sagemaker_session = sagemaker.Session()
script_path = 'preprocessing.py'
sklearn_preprocessor = SKLearn(
entry_point=script_path,
role=role,
train_instance_type="ml.c4.xlarge",
framework_version='0.20.0',
py_version = 'py3',
sagemaker_session=sagemaker_session)
sklearn_preprocessor.fit({'train': s3_input_train})
3. transform_train_output_path = 's3://{}/{}/{}/'.format(bucket, prefix, 'transformtrain-train-output')
scikit_learn_inferencee_model = sklearn_preprocessor.create_model(env={'TRANSFORM_MODE': 'feature-transform'})
transformer_train = scikit_learn_inferencee_model.transformer(
instance_count=1,
assemble_with = 'Line',
output_path = transform_train_output_path,
accept = 'text/csv',
strategy = "MultiRecord",
max_payload =40,
instance_type='ml.m4.xlarge')
4. Preprocess training input
transformer_train.transform(s3_input_train.config['DataSource']['S3DataSource']['S3Uri'],
content_type='text/csv',
split_type = "Line")
print('Waiting for transform job: ' + transformer_train.latest_transform_job.job_name)
transformer_train.wait()
preprocessed_train_path = transformer_train.output_path + transformer_train.latest_transform_job.job_name
preprocessing.py
from __future__ import print_function
import time
import sys
from io import StringIO
import os
import shutil
import argparse
import csv
import json
import numpy as np
import pandas as pd
import logging
from sklearn.compose import ColumnTransformer
from sklearn.externals import joblib
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import Binarizer, StandardScaler, OneHotEncoder
from sagemaker_containers.beta.framework import (
content_types, encoders, env, modules, transformer, worker)
# Specifying the column names here.
feature_columns_names = [
'A',
'B',
'C',
'D',
'E',
'F',
'G',
'H',
'I',
'J',
'K'
]
label_column = 'ab'
feature_columns_dtype = {
'A' : str,
'B' : np.float64,
'C' : np.float64,
'D' : str,
"E" : np.float64,
'F' : str,
'G' : str,
'H' : np.float64,
'I' : str,
'J' : str,
'K': str,
}
label_column_dtype = {'ab': np.int32}
def merge_two_dicts(x, y):
z = x.copy() # start with x's keys and values
z.update(y) # modifies z with y's keys and values & returns None
return z
def _is_inverse_label_transform():
"""Returns True if if it's running in inverse label transform."""
return os.getenv('TRANSFORM_MODE') == 'inverse-label-transform'
def _is_feature_transform():
"""Returns True if it's running in feature transform mode."""
return os.getenv('TRANSFORM_MODE') == 'feature-transform'
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Sagemaker specific arguments. Defaults are set in the environment variables.
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
args = parser.parse_args()
input_files = [ os.path.join(args.train, file) for file in os.listdir(args.train) ]
if len(input_files) == 0:
raise ValueError(('There are no files in {}.\n' +
'This usually indicates that the channel ({}) was incorrectly specified,\n' +
'the data specification in S3 was incorrectly specified or the role specified\n' +
'does not have permission to access the data.').format(args.train, "train"))
raw_data = [ pd.read_csv(
file,
header=None,
names=feature_columns_names + [label_column],
dtype=merge_two_dicts(feature_columns_dtype, label_column_dtype)) for file in input_files ]
concat_data = pd.concat(raw_data)
numeric_features = list([
'B',
'C',
'E',
'H'
])
numeric_transformer = Pipeline(steps=[
('imputer', SimpleImputer(strategy='median')),
('scaler', StandardScaler())])
categorical_features = list(['A','D','F','G','I','J','K'])
categorical_transformer = Pipeline(steps=[
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
('onehot', OneHotEncoder(handle_unknown='ignore'))])
preprocessor = ColumnTransformer(
transformers=[
('num', numeric_transformer, numeric_features),
('cat', categorical_transformer, categorical_features)],
remainder="drop")
preprocessor.fit(concat_data)
joblib.dump(preprocessor, os.path.join(args.model_dir, "model.joblib"))
print("saved model!")
def input_fn(input_data, request_content_type):
"""Parse input data payload
We currently only take csv input. Since we need to process both labelled
and unlabelled data we first determine whether the label column is present
by looking at how many columns were provided.
"""
content_type = request_content_type.lower(
) if request_content_type else "text/csv"
content_type = content_type.split(";")[0].strip()
if isinstance(input_data, str):
str_buffer = input_data
else:
str_buffer = str(input_data,'utf-8')
if _is_feature_transform():
if content_type == 'text/csv':
# Read the raw input data as CSV.
df = pd.read_csv(StringIO(input_data), header=None)
if len(df.columns) == len(feature_columns_names) + 1:
# This is a labelled example, includes the label
df.columns = feature_columns_names + [label_column]
elif len(df.columns) == len(feature_columns_names):
# This is an unlabelled example.
df.columns = feature_columns_names
return df
else:
raise ValueError("{} not supported by script!".format(content_type))
if _is_inverse_label_transform():
if (content_type == 'text/csv' or content_type == 'text/csv; charset=utf-8'):
# Read the raw input data as CSV.
df = pd.read_csv(StringIO(str_buffer), header=None)
if len(df.columns) == len(feature_columns_names) + 1:
# This is a labelled example, includes the ring label
df.columns = feature_columns_names + [label_column]
elif len(df.columns) == len(feature_columns_names):
# This is an unlabelled example.
df.columns = feature_columns_names
return df
else:
raise ValueError("{} not supported by script!".format(content_type))
def output_fn(prediction, accept):
"""Format prediction output
The default accept/content-type between containers for serial inference is JSON.
We also want to set the ContentType or mimetype as the same value as accept so the next
container can read the response payload correctly.
"""
accept = 'text/csv'
if type(prediction) is not np.ndarray:
prediction=prediction.toarray()
if accept == "application/json":
instances = []
for row in prediction.tolist():
instances.append({"features": row})
json_output = {"instances": instances}
return worker.Response(json.dumps(json_output), mimetype=accept)
elif accept == 'text/csv':
return worker.Response(encoders.encode(prediction, accept), mimetype=accept)
else:
raise RuntimeException("{} accept type is not supported by this script.".format(accept))
def predict_fn(input_data, model):
"""Preprocess input data
We implement this because the default predict_fn uses .predict(), but our model is a preprocessor
so we want to use .transform().
The output is returned in the following order:
rest of features either one hot encoded or standardized
"""
if _is_feature_transform():
features = model.transform(input_data)
if label_column in input_data:
# Return the label (as the first column) and the set of features.
return np.insert(features.toarray(), 0, pd.get_dummies(input_data[label_column])['True.'], axis=1)
else:
# Return only the set of features
return features
if _is_inverse_label_transform():
features = input_data.iloc[:,0]>0.5
features = features.values
return features
def model_fn(model_dir):
"""Deserialize fitted model
"""
if _is_feature_transform():
preprocessor = joblib.load(os.path.join(model_dir, "model.joblib"))
return preprocessor
Please help.
As I can see you are referring to a post that is pretty old and has issues open on Github here with regards to the Input source and the configurations.
I encourage you to check out the latest examples here which shows how to visualize Amazon SageMaker machine learning predictions with Amazon QuickSight.
Additionally, if the problem persists please open a service request with AWS Support with job ARN to investigate future on the Internal Server Error.

WatsonApiException: Error: invalid-api-key, Code: 401

I cant find Alchemy Language API in IBM Watson.
Can I do this with natural-language-understanding service and how?
When I add
from watson_developer_cloud import NaturalLanguageUnderstandingV1
from watson_developer_cloud.natural_language_understanding_v1 \
import Features, EntitiesOptions, KeywordsOptions
It shows some error with combined keyword
# In[]:
import tweepy
import re
import time
import math
import pandas as pd
from watson_developer_cloud import AlchemyLanguageV1
def initAlchemy():
al = AlchemyLanguageV1(api_key='GRYVUMdBbOtJXxNOIs1aopjjaiyOmLG7xJBzkAnvvwLh')
return al
def initTwitterApi():
consumer_key = 'OmK1RrZCVJSRmKxIuQqkBExvw'
consumer_secret = 'VWn6OR4rRgSi7qGnZHCblJMhrSvj1QbJmf0f62uX6ZQWZUUx5q'
access_token = '4852231552-adGooMpTB3EJYPHvs6oGZ40qlo3d2JbVjqUUWkJ'
access_token_secret = 'm9hgeM9p0r1nn8IoQWJYBs5qUQu56XmrAhsDSYKjuiVA4'
auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
auth.set_access_token(access_token, access_token_secret)
api = tweepy.API(auth)
return api
'''This function is implemented to handle tweepy exception errors
because search is rate limited at 180 queries per 15 minute window by twitter'''
def limit(cursor):
while True:
try:
yield cursor.next()
except tweepy.TweepError as error:
print(repr(error))
print("Twitter Request limit error reached sleeping for 15 minutes")
time.sleep(16*60)
except tweepy.RateLimitError:
print("Rate Limit Error occurred Sleeping for 16 minutes")
time.sleep(16*60)
def retrieveTweets(api, search, lim):
if(lim == ""):
lim = math.inf
else:
lim = int(lim)
text = []
for tweet in limit(tweepy.Cursor(api.search, q=search).items(limit = lim)):
t = re.sub('\s+', ' ', tweet.text)
text.append(t)
data = {"Tweet":text,
"Sentiment":"",
"Score":""}
dataFrame = pd.DataFrame(data, columns=["Tweet","Sentiment","Score"])
return dataFrame
def analyze(al,dataFrame):
sentiment = []
score = []
for i in range(0, dataFrame["Tweet"].__len__()):
res = al.combined(text=dataFrame["Tweet"][i],
extract="doc-sentiment",
sentiment=1)
sentiment.append(res["docSentiment"]["type"])
if(res["docSentiment"]["type"] == "neutral"):
score.append(0)
else:
score.append(res["docSentiment"]["score"])
dataFrame["Sentiment"] = sentiment
dataFrame["Score"] = score
return dataFrame
def main():
#Initialse Twitter Api
api = initTwitterApi()
#Retrieve tweets
dataFrame = retrieveTweets(api,input("Enter the search query (e.g. #hillaryclinton ) : "), input("Enter limit for number of tweets to be searched or else just hit enter : "))
#Initialise IBM Watson Alchemy Language Api
al = initAlchemy()
#Do Document Sentiment analysis
dataFrame = analyze(al, dataFrame)
#Save tweets, sentiment, and score data frame in csv file
dataFrame.to_csv(input("Enter the name of the file (with .csv extension) : "))
if __name__ == '__main__':
main()# -*- coding: utf-8 -*-
The Watson Natural Language Understanding only has a combined call, but since it is the only call, it isn't called combined, its actually analyze. Best place to go for details would be the API documentation - https://www.ibm.com/watson/developercloud/natural-language-understanding/api/v1/?python#post-analyze

Wipe out dropout operations from TensorFlow graph

I have a trained freezed graph that I am trying to run on an ARM device. Basically, I am using contrib/pi_examples/label_image, but with my network instead of Inception. My network was trained with dropout, which now causes me troubles:
Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs. Registered kernels:
device='CPU'; T in [DT_FLOAT]
device='CPU'; T in [DT_INT32]
device='GPU'; T in [DT_STRING]
device='GPU'; T in [DT_BOOL]
device='GPU'; T in [DT_INT32]
device='GPU'; T in [DT_FLOAT]
[[Node: l_fc1_dropout/cond/Switch = Switch[T=DT_BOOL](is_training_pl, is_training_pl)]]
One solution I can see is to build such TF static library that includes the corresponding operation. From other hand, it might be a better idea to eliminate the dropout ops from the network in order to make it simpler and faster. Is there a way to do that?
Thanks.
#!/usr/bin/env python2
import argparse
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
def print_graph(input_graph):
for node in input_graph.node:
print "{0} : {1} ( {2} )".format(node.name, node.op, node.input)
def strip(input_graph, drop_scope, input_before, output_after, pl_name):
input_nodes = input_graph.node
nodes_after_strip = []
for node in input_nodes:
print "{0} : {1} ( {2} )".format(node.name, node.op, node.input)
if node.name.startswith(drop_scope + '/'):
continue
if node.name == pl_name:
continue
new_node = node_def_pb2.NodeDef()
new_node.CopyFrom(node)
if new_node.name == output_after:
new_input = []
for node_name in new_node.input:
if node_name == drop_scope + '/cond/Merge':
new_input.append(input_before)
else:
new_input.append(node_name)
del new_node.input[:]
new_node.input.extend(new_input)
nodes_after_strip.append(new_node)
output_graph = graph_pb2.GraphDef()
output_graph.node.extend(nodes_after_strip)
return output_graph
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--input-graph', action='store', dest='input_graph')
parser.add_argument('--input-binary', action='store_true', default=True, dest='input_binary')
parser.add_argument('--output-graph', action='store', dest='output_graph')
parser.add_argument('--output-binary', action='store_true', dest='output_binary', default=True)
args = parser.parse_args()
input_graph = args.input_graph
input_binary = args.input_binary
output_graph = args.output_graph
output_binary = args.output_binary
if not tf.gfile.Exists(input_graph):
print("Input graph file '" + input_graph + "' does not exist!")
return
input_graph_def = tf.GraphDef()
mode = "rb" if input_binary else "r"
with tf.gfile.FastGFile(input_graph, mode) as f:
if input_binary:
input_graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read().decode("utf-8"), input_graph_def)
print "Before:"
print_graph(input_graph_def)
output_graph_def = strip(input_graph_def, u'l_fc1_dropout', u'l_fc1/Relu', u'prediction/MatMul', u'is_training_pl')
print "After:"
print_graph(output_graph_def)
if output_binary:
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
else:
with tf.gfile.GFile(output_graph, "w") as f:
f.write(text_format.MessageToString(output_graph_def))
print("%d ops in the final graph." % len(output_graph_def.node))
if __name__ == "__main__":
main()
How about this as a more general solution:
for node in temp_graph_def.node:
for idx, i in enumerate(node.input):
input_clean = node_name_from_input(i)
if input_clean.endswith('/cond/Merge') and input_clean.split('/')[-3].startswith('dropout'):
identity = node_from_map(input_node_map, i).input[0]
assert identity.split('/')[-1] == 'Identity'
parent = node_from_map(input_node_map, node_from_map(input_node_map, identity).input[0])
pred_id = parent.input[1]
assert pred_id.split('/')[-1] == 'pred_id'
good = parent.input[0]
node.input[idx] = good

Resources