Machine Learning in Production: Best Practices
Introduction
Deploying machine learning models to production is fundamentally different from training them in a notebook. Production ML systems need to be reliable, scalable, maintainable, and continuously improving.
Production Architecture
1. Model Serving
REST API Approach
import hashlib
import random
from datetime import datetime
from typing import List

import joblib
import numpy as np
from flask import Flask, request, jsonify
app = Flask(__name__)

# Load trained model once at import time so every request reuses it.
# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code -- only load model artifacts from a trusted source.
model = joblib.load('model.pkl')
@app.route('/predict', methods=['POST'])
def predict():
    """Score one feature vector posted as JSON and return the prediction.

    The request body must be a JSON object with a ``features`` entry, which
    is reshaped into a single-row array before being passed to the model.

    Returns:
        A JSON response with ``prediction``, ``confidence`` (the largest
        value returned by ``predict_proba``) and an ISO-formatted
        ``timestamp``. Malformed requests get status 400 and failures while
        scoring get status 500, both with an ``error`` message.
    """
    # silent=True yields None instead of raising on a missing or invalid
    # body, so bad client input is reported as 400 rather than 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'features' not in data:
        return jsonify({
            'error': "Request body must be a JSON object with a 'features' field"
        }), 400

    try:
        features = np.array(data['features']).reshape(1, -1)
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid features: {e}'}), 400

    try:
        prediction = model.predict(features)[0]
        confidence = model.predict_proba(features)[0].max()
        payload = {
            'prediction': int(prediction),
            'confidence': float(confidence),
            'timestamp': datetime.now().isoformat(),
        }
    except Exception as e:
        # Request boundary: report any scoring failure as a JSON 500.
        return jsonify({'error': str(e)}), 500

    return jsonify(payload)
if __name__ == '__main__':
    # Only start the server when this file is executed directly.
    app.run(port=5000, host='0.0.0.0')
2. Model Versioning
Semantic Versioning
Implement proper model versioning:
class ModelRegistry:
    """In-memory registry of model versions and the version currently served.

    Attributes:
        models: Maps ``"<name>:<version>"`` to that version's path, metrics
            and registration time.
        current_versions: Maps a model name to its promoted version.
    """

    def __init__(self):
        self.models = {}
        self.current_versions = {}
        # name -> versions in the order they were promoted (last is current).
        # Rollback needs this to know what was served before.
        self._promotion_history = {}

    def register_model(self, name: str, model_path: str, version: str, metrics: dict):
        """Record a model version together with its artifact path and metrics."""
        self.models[f"{name}:{version}"] = {
            'path': model_path,
            'metrics': metrics,
            'created_at': datetime.now()
        }

    def promote_model(self, name: str, version: str):
        """Make ``version`` the current version of ``name``.

        Raises:
            KeyError: If ``version`` was never registered for ``name``.
        """
        if f"{name}:{version}" not in self.models:
            raise KeyError(f"Model version not registered: {name}:{version}")
        self.current_versions[name] = version
        self._promotion_history.setdefault(name, []).append(version)
        # Update load balancer to point to new version
        self.update_load_balancer(name, version)

    def get_previous_version(self, name: str):
        """Return the version promoted before the current one, or ``None``."""
        history = self._promotion_history.get(name, [])
        return history[-2] if len(history) >= 2 else None

    def rollback_model(self, name: str):
        """Re-promote the previously promoted version of ``name``, if any."""
        previous_version = self.get_previous_version(name)
        if previous_version is not None:
            # Drop the current and the previous entry; promote_model appends
            # the previous one again, so the history shrinks by one and
            # repeated rollbacks keep walking backwards.
            del self._promotion_history[name][-2:]
            self.promote_model(name, previous_version)

    def update_load_balancer(self, name: str, version: str):
        """Hook called after a promotion; does nothing by default.

        Override in a subclass to route traffic to ``version``.
        """
3. Monitoring and Observability
Model Performance Monitoring
class ModelMonitor:
    """Buffers prediction records for one model and checks them for drift.

    ``DataDriftDetector``, ``DRIFT_THRESHOLD``, ``alert_team`` and
    ``trigger_retraining_pipeline`` are not defined in this snippet and must
    be provided elsewhere.

    Args:
        model_name: Name written into every record and alert message.
        window_size: Number of predictions per drift check.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """

    def __init__(self, model_name: str, window_size: int = 100):
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        self.model_name = model_name
        self.window_size = window_size
        # NOTE(review): this list is never trimmed, so it grows for the
        # lifetime of the monitor.
        self.prediction_buffer = []
        self.drift_detector = DataDriftDetector()

    def log_prediction(self, input_data, prediction, confidence, ground_truth=None):
        """Append one prediction record and run a drift check per full window."""
        prediction_log = {
            'timestamp': datetime.now(),
            'model': self.model_name,
            # NOTE(review): built-in hash() of a str is salted per process
            # unless PYTHONHASHSEED is fixed, so this value is not comparable
            # across runs.
            'input_hash': hash(str(input_data)),
            'prediction': prediction,
            'confidence': confidence,
            'ground_truth': ground_truth
        }
        self.prediction_buffer.append(prediction_log)

        # Check for data drift once per completed window. Checking on every
        # call after the buffer fills would repeat the alert and retraining
        # trigger for each new prediction.
        if len(self.prediction_buffer) % self.window_size == 0:
            self.check_drift()

    def check_drift(self):
        """Score the latest window; alert and trigger retraining above the threshold."""
        recent_predictions = self.prediction_buffer[-self.window_size:]
        drift_score = self.drift_detector.detect_drift(recent_predictions)

        if drift_score > DRIFT_THRESHOLD:
            self.alert_team(f"Data drift detected for {self.model_name}")
            self.trigger_retraining_pipeline()
Deployment Strategies
1. Blue-Green Deployment
# Kubernetes deployment
# "Blue" Deployment of the model server: three replicas of ml-model:v1.2.0.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ml-model-blue
spec:
  replicas: 3
  selector:
    # Must match the pod template labels below.
    matchLabels:
      app: ml-model
      version: blue
  template:
    metadata:
      labels:
        app: ml-model
        version: blue
    spec:
      containers:
        - name: ml-model
          image: ml-model:v1.2.0
          ports:
            # Same port the Flask app listens on.
            - containerPort: 5000
          resources:
            requests:
              memory: "256Mi"
              cpu: "250m"
            limits:
              memory: "512Mi"
              cpu: "500m"
2. A/B Testing Models
class ModelABTest:
    """Routes predictions between two models and records them per variant.

    Args:
        model_a: Object with a ``predict(input_data)`` method (variant ``'A'``).
        model_b: Object with a ``predict(input_data)`` method (variant ``'B'``).
        traffic_split: Probability in ``[0, 1]`` of routing a call to
            ``model_a``; the remaining calls go to ``model_b``.

    Raises:
        ValueError: If ``traffic_split`` is outside ``[0, 1]``.
    """

    def __init__(self, model_a, model_b, traffic_split=0.5):
        if not 0 <= traffic_split <= 1:
            raise ValueError("traffic_split must be between 0 and 1")
        self.model_a = model_a
        self.model_b = model_b
        self.traffic_split = traffic_split
        # Variant label -> list of logged prediction records.
        self.performance_tracker = {'A': [], 'B': []}

    def predict(self, input_data):
        """Pick a variant at random, predict with it, and log the result.

        Returns:
            A ``(prediction, model_version)`` tuple where ``model_version``
            is ``'A'`` or ``'B'``.
        """
        if random.random() < self.traffic_split:
            prediction = self.model_a.predict(input_data)
            model_version = 'A'
        else:
            prediction = self.model_b.predict(input_data)
            model_version = 'B'

        # Log prediction for analysis
        self.performance_tracker[model_version].append({
            'input': input_data,
            'prediction': prediction,
            'timestamp': datetime.now()
        })
        return prediction, model_version

    def evaluate_performance(self, ground_truths):
        """Return the accuracy of each variant's logged predictions.

        Args:
            ground_truths: Either a dict mapping ``'A'`` / ``'B'`` to the
                true labels for that variant's logged predictions (in the
                order they were logged), or a single sequence that is
                compared against each variant's predictions.

        Returns:
            A dict mapping ``'A'`` and ``'B'`` to the accuracy score.

        Raises:
            ValueError: If a variant's labels and predictions differ in
                length.
        """
        results = {}
        for version in ['A', 'B']:
            predictions = [p['prediction'] for p in self.performance_tracker[version]]
            # Each variant only sees part of the traffic, so labels normally
            # have to be supplied per variant to line up with its predictions.
            if isinstance(ground_truths, dict):
                expected = ground_truths[version]
            else:
                expected = ground_truths
            if len(expected) != len(predictions):
                raise ValueError(
                    f"Variant {version}: got {len(expected)} ground truths "
                    f"for {len(predictions)} predictions"
                )
            # NOTE(review): accuracy_score is not defined in this snippet;
            # presumably sklearn.metrics.accuracy_score -- confirm.
            results[version] = accuracy_score(expected, predictions)
        return results
Best Practices
1. Data Pipeline Management
Feature Store Implementation
class FeatureStore:
    """Read-through cache in front of a feature storage backend.

    Args:
        storage_backend: Object providing ``get_feature(entity_id, name)``
            and ``update_feature(entity_id, name, value)``.
    """

    def __init__(self, storage_backend):
        self.storage = storage_backend
        # (entity_id, (feature_name, ...)) -> {feature_name: value}
        self.feature_cache = {}

    def get_features(self, entity_id: str, feature_names: List[str]):
        """Return ``{feature_name: value}`` for one entity, using the cache.

        On a miss every requested feature is read from the backend and the
        resulting dict is cached under the exact list of names requested.
        """
        # A tuple key keeps the entity and each feature name distinct; a
        # joined string would make ['a_b'] and ['a', 'b'] collide.
        cache_key = (entity_id, tuple(feature_names))
        try:
            return self.feature_cache[cache_key]
        except KeyError:
            pass

        features = {
            feature_name: self.storage.get_feature(entity_id, feature_name)
            for feature_name in feature_names
        }
        self.feature_cache[cache_key] = features
        return features

    def update_features(self, entity_id: str, features: dict):
        """Write feature values to the backend and drop the entity's cache entries."""
        for feature_name, value in features.items():
            self.storage.update_feature(entity_id, feature_name, value)

        # Invalidate cache. Match the entity id exactly so that updating
        # 'user1' does not also evict entries for 'user10'.
        stale_keys = [key for key in self.feature_cache if key[0] == entity_id]
        for key in stale_keys:
            del self.feature_cache[key]
2. Security and Compliance
Model Security
class ModelSecurity:
    """Gatekeeper that authenticates requests and records rejected ones.

    ``RateLimiter``, ``validate_api_key``, ``validate_input`` and
    ``send_to_security_monitoring`` are not defined in this snippet and must
    be provided elsewhere.
    """

    def __init__(self):
        self.access_log = []
        self.rate_limiter = RateLimiter()

    def authenticate_request(self, api_key: str, request_data: dict):
        """Return ``True`` if the request passes every check, else ``False``.

        Checks run in order: API key, rate limit, input validation. The
        first failing check is logged as a security event.
        """
        # Validate API key
        if not self.validate_api_key(api_key):
            self.log_security_event("Invalid API key", request_data, api_key=api_key)
            return False

        # Check rate limits
        if not self.rate_limiter.is_allowed(api_key):
            self.log_security_event("Rate limit exceeded", request_data, api_key=api_key)
            return False

        # Validate input data
        if not self.validate_input(request_data):
            self.log_security_event("Invalid input data", request_data, api_key=api_key)
            return False

        return True

    def log_security_event(self, event_type: str, request_data: dict, api_key: str = None):
        """Record a security event and forward it to security monitoring.

        Args:
            event_type: Short description of what was rejected.
            request_data: Request metadata; ``ip`` and ``user_agent`` are
                copied into the event when present.
            api_key: Key presented by the caller. When omitted, the
                ``api_key`` entry of ``request_data`` is used (empty string
                if absent).
        """
        if api_key is None:
            api_key = request_data.get('api_key', '')

        # SHA-256 gives the same digest in every process; built-in hash() of
        # a str is salted per interpreter run, so its values cannot be
        # correlated across restarts or workers.
        api_key_hash = hashlib.sha256(str(api_key).encode('utf-8')).hexdigest()

        security_event = {
            'timestamp': datetime.now(),
            'event_type': event_type,
            'ip_address': request_data.get('ip'),
            'user_agent': request_data.get('user_agent'),
            'api_key_hash': api_key_hash
        }
        self.access_log.append(security_event)

        # Send to security monitoring system
        self.send_to_security_monitoring(security_event)
Conclusion
Machine learning in production requires a different mindset than model development. Focus on reliability, monitoring, security, and continuous improvement rather than just model accuracy.
Remember that production ML systems are software systems first and ML systems second. Apply software engineering best practices and build robust, maintainable systems that can evolve with your business needs.
